blob: e575cf6deb83562ad49c748ac71ed4201693f47e [file] [log] [blame]
/*
 * Copyright (c) 2017-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "gemm_helpers.h"
#include "repeat.h"

#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Per-lane offsets (0, 1, ..., n - 1) for every supported vector width n
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
// Two-level expansion so that K0 is replaced by its value before being pasted onto INC
#define CONCAT_INC(K0) INC##K0
#define INC(K0) CONCAT_INC(K0)

#if(SRC_WIDTH % K0)
/** Mask, in place, the lanes of a K0-wide row vector against the source width.
 *
 * Lane i of @p a is associated with the column index (x * K0 + i). The comparison of that index against
 * SRC_WIDTH is used as the select() mask between @p a and 0, with the intent of keeping the lanes below
 * SRC_WIDTH and zeroing the others.
 *
 * @note NOTE(review): select() takes an integer mask with the same element size as its operands, so this
 *       presumably relies on DATA_TYPE being an integer type of the element's size - confirm against the host code.
 *
 * @param[in]      x Index of the K0-wide block along the X dimension
 * @param[in, out] a Row vector of K0 elements, updated in place
 */
#define BOUNDARY_CONDITION_X(x, a)                                                                                                                                       \
    ({                                                                                                                                                                   \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
/** SRC_WIDTH is a multiple of K0: the boundary check is compiled out and @p a is left untouched */
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the starts of two blocks sharing an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the K0-wide block column, y the M0-high block row, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive block rows (y) are packed on the same output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    // Load values from the LHS matrix into the row vectors a0 ... a(M0-1)
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Apply the X boundary condition to every loaded row
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------
    // Create variables: uint zout0=0, zout1=0, ... (no extra Z offset is applied on the output side)
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}

/** Gather component i of the row vectors a0 ... a(M0-1) into one column of M0 elements and store it.
 *
 * The column is written at byte offset (0x<i> * output_step_x * sizeof(DATA_TYPE)) from @p output_ptr.
 * For M0 = 5, 6, 7 the column is written in two parts (4 elements followed by the remaining 1, 2 or 3).
 *
 * @param[in] output_ptr    Byte pointer to the start of the destination block
 * @param[in] output_step_x Distance, in elements, between two consecutive columns in the destination
 * @param[in] i             Column index as a single hexadecimal digit (0-9, A-F); it is pasted both onto the
 *                          vector component selector (.s<i>) and onto the literal 0x<i>
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                              \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i);                                                   \
        VSTORE(M0)                                                                                                \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));             \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                              \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i);                                          \
        VSTORE(M0)                                                                                                \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));             \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                              \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                 \
        VSTORE(M0)                                                                                                \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));             \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                               \
        res0           = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                       \
        DATA_TYPE res1 = a4.s##i;                                                                                 \
        VSTORE(4)                                                                                                 \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));            \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1;         \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                               \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                 \
        VEC_DATA_TYPE(DATA_TYPE, 2)                                                                               \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i);                                                   \
        VSTORE(4)                                                                                                 \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));            \
        VSTORE(2)                                                                                                 \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4);        \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                  \
    ({                                                                                                            \
        VEC_DATA_TYPE(DATA_TYPE, 4)                                                                               \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i);                                 \
        VEC_DATA_TYPE(DATA_TYPE, 3)                                                                               \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i);                                          \
        VSTORE(4)                                                                                                 \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));            \
        VSTORE(3)                                                                                                 \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4);        \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i)                                                      \
    ({                                                                                                                \
        VEC_DATA_TYPE(DATA_TYPE, M0)                                                                                  \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0)                                                                                                    \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)));                 \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

    // OUTPUT_OFFSET_X: distance (in elements) between the starts of two blocks sharing an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive transposed columns of the same block
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the K0-wide block column, y the M0-high block row, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 consecutive block rows (y) are packed on the same output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix into the row vectors a0 ... a(M0-1)
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Apply the X boundary condition to every loaded row
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per source column; the last argument is the column index as a hexadecimal digit
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)

#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // OUTPUT_OFFSET_X: distance (in elements) between the starts of two blocks sharing an output row
    // OUTPUT_STEP_X:   distance (in elements) between two consecutive rows of the same block in the output
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the N0-wide block column, y the K0-high block row, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive block columns (x) are packed on the same output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
                                     x / (uint)H0)
                                 * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, a2=0 ... one per row of the block (K0 in total)
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix.
    // The first row is loaded unconditionally; every further row is loaded only if it lies inside the
    // source (row index < SRC_HEIGHT), otherwise its vector keeps the zero it was initialized with.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    // From row 10 onwards the variable suffix is the row index as a hexadecimal digit (aA ... aF)
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // Create variables: uint zout0=0, zout1=0, ... (no extra Z offset is applied on the output side)
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}

#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 2,3,4,8,16
 *                                      H0: greater than 0
 * @note The first row of every block is always read. Rows 1..K0-1 are read only when they lie below SRC_HEIGHT;
 *       rows that are skipped keep their zero initialisation and are stored as zeros.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between the starts of two consecutive blocks on the same output row
    // Output step X:   distance (in elements) between two consecutive transposed rows of the same block
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the block along the source width, y along the source height, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, ... a(K0-1)=0;

    // Loads block row ROW into a<SUFFIX> when that row lies below SRC_HEIGHT; otherwise a<SUFFIX> keeps its zero value.
    // SUFFIX is the hexadecimal spelling of ROW (e.g. A for 10) so that it matches the variable names created above.
#define RHS_T_LOAD_ROW(SUFFIX, ROW)                                                          \
    do                                                                                       \
    {                                                                                        \
        if(y * (uint)K0 + (ROW) < SRC_HEIGHT)                                                \
        {                                                                                    \
            a##SUFFIX = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + (ROW) * src_stride_y)); \
        }                                                                                    \
    } while(0)

    // Load values from the RHS matrix (the first row of the block is read unconditionally)
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    RHS_T_LOAD_ROW(1, 1);
#if K0 > 2
    RHS_T_LOAD_ROW(2, 2);
#endif // K0 > 2
#if K0 > 3
    RHS_T_LOAD_ROW(3, 3);
#endif // K0 > 3
#if K0 > 4
    RHS_T_LOAD_ROW(4, 4);
    RHS_T_LOAD_ROW(5, 5);
    RHS_T_LOAD_ROW(6, 6);
    RHS_T_LOAD_ROW(7, 7);
#endif // K0 > 4
#if K0 > 8
    RHS_T_LOAD_ROW(8, 8);
    RHS_T_LOAD_ROW(9, 9);
    RHS_T_LOAD_ROW(A, 10);
    RHS_T_LOAD_ROW(B, 11);
    RHS_T_LOAD_ROW(C, 12);
    RHS_T_LOAD_ROW(D, 13);
    RHS_T_LOAD_ROW(E, 14);
    RHS_T_LOAD_ROW(F, 15);
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0) res0=0, res1=0, res2=0,... res(N0-1)=0;

    // RHS_T_COLUMN(I) gathers component I of every loaded row a0..a(K0-1) into one K0-wide vector,
    // i.e. column I of the K0xN0 block, which becomes row I of the transposed N0xK0 block.
#if K0 == 2
#define RHS_T_COLUMN(I) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##I, a1.s##I)
#elif K0 == 3 // K0 == 3
#define RHS_T_COLUMN(I) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##I, a1.s##I, a2.s##I)
#elif K0 == 4 // K0 == 4
#define RHS_T_COLUMN(I) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##I, a1.s##I, a2.s##I, a3.s##I)
#elif K0 == 8 // K0 == 8
#define RHS_T_COLUMN(I) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##I, a1.s##I, a2.s##I, a3.s##I, a4.s##I, a5.s##I, a6.s##I, a7.s##I)
#elif K0 == 16 // K0 == 16
#define RHS_T_COLUMN(I) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##I, a1.s##I, a2.s##I, a3.s##I, a4.s##I, a5.s##I, a6.s##I, a7.s##I, \
                                                       a8.s##I, a9.s##I, aA.s##I, aB.s##I, aC.s##I, aD.s##I, aE.s##I, aF.s##I)
#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // One assignment per column of the source block: K0xN0 -> N0xK0
    res0 = RHS_T_COLUMN(0);
    res1 = RHS_T_COLUMN(1);
#if N0 > 2
    res2 = RHS_T_COLUMN(2);
#endif // N0 > 2
#if N0 > 3
    res3 = RHS_T_COLUMN(3);
#endif // N0 > 3
#if N0 > 4
    res4 = RHS_T_COLUMN(4);
    res5 = RHS_T_COLUMN(5);
    res6 = RHS_T_COLUMN(6);
    res7 = RHS_T_COLUMN(7);
#endif // N0 > 4
#if N0 > 8
    res8 = RHS_T_COLUMN(8);
    res9 = RHS_T_COLUMN(9);
    resA = RHS_T_COLUMN(A);
    resB = RHS_T_COLUMN(B);
    resC = RHS_T_COLUMN(C);
    resD = RHS_T_COLUMN(D);
    resE = RHS_T_COLUMN(E);
    resF = RHS_T_COLUMN(F);
#endif // N0 > 8

    // ---------------------------Store the output values ------------------------------
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef RHS_T_COLUMN
#undef RHS_T_LOAD_ROW
#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
// Pastes two tokens together; used below to select ARM_DOT<k0> from a block-size argument
#define CONCAT(a, b) a##b

// ARM_DOT<k>(a, b, c): accumulates the dot product of the k-element operands a and b into c
// with one fma per element, lowest component first. c must be a modifiable scalar lvalue.
// Every parameter is parenthesized before a member access so that callers may pass any expression.
#define ARM_DOT1(a, b, c)         \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
#define ARM_DOT2(a, b, c)               \
    ({                                  \
        (c) = fma((a).s0, (b).s0, (c)); \
        (c) = fma((a).s1, (b).s1, (c)); \
    })
#define ARM_DOT3(a, b, c)               \
    ({                                  \
        ARM_DOT2((a), (b), (c));        \
        (c) = fma((a).s2, (b).s2, (c)); \
    })
#define ARM_DOT4(a, b, c)               \
    ({                                  \
        ARM_DOT3((a), (b), (c));        \
        (c) = fma((a).s3, (b).s3, (c)); \
    })
// The wider variants process the low half of the operands first, then the high half
#define ARM_DOT8(a, b, c)                    \
    ({                                       \
        ARM_DOT4(((a).lo), ((b).lo), (c));   \
        ARM_DOT4(((a).hi), ((b).hi), (c));   \
    })
#define ARM_DOT16(a, b, c)                   \
    ({                                       \
        ARM_DOT8(((a).lo), ((b).lo), (c));   \
        ARM_DOT8(((a).hi), ((b).hi), (c));   \
    })

// ARM_DOT_K0XN0(k0, a, b, c): for every i in [0, N0) accumulates dot(a, b<i>) into component i of c.
//  - k0: number of elements in a and in each b<i> (selects ARM_DOT<k0>)
//  - a : one k0-element operand
//  - b : basename of the N0 operands b0, b1, ... (hexadecimal suffixes, e.g. bA..bF when N0 == 16)
//  - c : N0-wide accumulator vector
#if N0 == 2
#define ARM_DOT_K0XN0(k0, a, b, c)                      \
    ({                                                  \
        CONCAT(ARM_DOT, k0)((a), (b##0), ((c).s0));     \
        CONCAT(ARM_DOT, k0)((a), (b##1), ((c).s1));     \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(k0, a, b, c)                      \
    ({                                                  \
        CONCAT(ARM_DOT, k0)((a), (b##0), ((c).s0));     \
        CONCAT(ARM_DOT, k0)((a), (b##1), ((c).s1));     \
        CONCAT(ARM_DOT, k0)((a), (b##2), ((c).s2));     \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(k0, a, b, c)                      \
    ({                                                  \
        CONCAT(ARM_DOT, k0)((a), (b##0), ((c).s0));     \
        CONCAT(ARM_DOT, k0)((a), (b##1), ((c).s1));     \
        CONCAT(ARM_DOT, k0)((a), (b##2), ((c).s2));     \
        CONCAT(ARM_DOT, k0)((a), (b##3), ((c).s3));     \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(k0, a, b, c)                      \
    ({                                                  \
        CONCAT(ARM_DOT, k0)((a), (b##0), ((c).s0));     \
        CONCAT(ARM_DOT, k0)((a), (b##1), ((c).s1));     \
        CONCAT(ARM_DOT, k0)((a), (b##2), ((c).s2));     \
        CONCAT(ARM_DOT, k0)((a), (b##3), ((c).s3));     \
        CONCAT(ARM_DOT, k0)((a), (b##4), ((c).s4));     \
        CONCAT(ARM_DOT, k0)((a), (b##5), ((c).s5));     \
        CONCAT(ARM_DOT, k0)((a), (b##6), ((c).s6));     \
        CONCAT(ARM_DOT, k0)((a), (b##7), ((c).s7));     \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(k0, a, b, c)                      \
    ({                                                  \
        CONCAT(ARM_DOT, k0)((a), (b##0), ((c).s0));     \
        CONCAT(ARM_DOT, k0)((a), (b##1), ((c).s1));     \
        CONCAT(ARM_DOT, k0)((a), (b##2), ((c).s2));     \
        CONCAT(ARM_DOT, k0)((a), (b##3), ((c).s3));     \
        CONCAT(ARM_DOT, k0)((a), (b##4), ((c).s4));     \
        CONCAT(ARM_DOT, k0)((a), (b##5), ((c).s5));     \
        CONCAT(ARM_DOT, k0)((a), (b##6), ((c).s6));     \
        CONCAT(ARM_DOT, k0)((a), (b##7), ((c).s7));     \
        CONCAT(ARM_DOT, k0)((a), (b##8), ((c).s8));     \
        CONCAT(ARM_DOT, k0)((a), (b##9), ((c).s9));     \
        CONCAT(ARM_DOT, k0)((a), (b##A), ((c).sA));     \
        CONCAT(ARM_DOT, k0)((a), (b##B), ((c).sB));     \
        CONCAT(ARM_DOT, k0)((a), (b##C), ((c).sC));     \
        CONCAT(ARM_DOT, k0)((a), (b##D), ((c).sD));     \
        CONCAT(ARM_DOT, k0)((a), (b##E), ((c).sE));     \
        CONCAT(ARM_DOT, k0)((a), (b##F), ((c).sF));     \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Sheri Zhang1a378102020-04-30 12:59:39 +01001034 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001037 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Sheri Zhang1a378102020-04-30 12:59:39 +01001058 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
Sheri Zhang1a378102020-04-30 12:59:39 +01001115 // Compute RHS reshaped matrix address
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
Sheri Zhang1a378102020-04-30 12:59:39 +01001161 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
Sheri Zhang1a378102020-04-30 12:59:39 +01001198 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
/** Fused multiply-add with in-place accumulation: c = fma(a, b, c).
 *
 * Written as a statement expression so that it can be used like a single statement
 * inside other multi-statement macros.
 *
 * @param[in]      a First multiplicand
 * @param[in]      b Second multiplicand
 * @param[in, out] c Accumulator. It is read as the addend and overwritten with the result
 */
#define VFMA(a, b, c)     \
    ({                    \
        c = fma(a, b, c); \
    })

/** Load one N0-wide vector from the RHS matrix and accumulate it into the M0 row accumulators.
 *
 * The vector is read with VLOAD(N0) from
 *   rhs_ptr + rhs_offset + i * RHS_STEP_X * sizeof(DATA_TYPE)
 * and, for every row r in [0, M0), component i of a<r> is broadcast to an N0-wide vector and
 * fused-multiply-added with the loaded vector into c<r>.
 *
 * @note The expansion refers to rhs_ptr, rhs_offset and RHS_STEP_X by name, so they must be visible
 *       where the macro is used. VFMA must be defined as well.
 * @note One definition is provided for each supported M0 (1 to 8). Any other value stops the build.
 *
 * @param[in]      i Single hexadecimal digit (0-9, A-F). It is pasted to form both the literal 0x<i>
 *                   used in the address computation and the vector component selector .s<i>
 * @param[in]      a Basename of the LHS vectors (a0, a1, ...)
 * @param[in, out] c Basename of the accumulators (c0, c1, ...)
 */
#if M0 == 1
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
    })
#elif M0 == 2 // M0 == 2
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
    })
#elif M0 == 3 // M0 == 3
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
    })
#elif M0 == 4 // M0 == 4
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
    })
#elif M0 == 5 // M0 == 5
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
    })
#elif M0 == 6 // M0 == 6
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
    })
#elif M0 == 7 // M0 == 7
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                            \
    })
#elif M0 == 8 // M0 == 8
#define LD_RHS_VFMA_M0xN0(i, a, c)                                                                               \
    ({                                                                                                           \
        VEC_DATA_TYPE(DATA_TYPE, N0)                                                                             \
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##i * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6));                                            \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7));                                            \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported

1389/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1390 * The LHS matrix is NOT reshaped
1391 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
1392 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001393 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001394 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and and -DK (e.g. -DM=52, -DN=30 and -DK=90).
1395 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1396 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1397 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001398 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1399 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1400 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1401 * - N0 = 2, 3, 4, 8, 16
1402 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001403 * - H0 >= 1
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001404 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001405 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001406 * The activation function is performed after the bias addition
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001407 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1408 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1409 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1410 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1411 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1412 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1413 *
Sheri Zhang1a378102020-04-30 12:59:39 +01001414 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1415 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001416 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001417 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001418 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001419 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001420 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1421 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1422 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1423 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1424 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1425 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001426 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1427 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001428 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001429 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001430 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1431 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1432 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1433 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1434 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1435 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1436 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1437 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Sheri Zhang1a378102020-04-30 12:59:39 +01001438 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001439 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001440 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001441 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1442 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1443 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001444 */
1445__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
1446 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001447#if defined(BETA)
1448 IMAGE_DECLARATION(bias),
1449#endif // defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001450 IMAGE_DECLARATION(dst),
1451 uint lhs_stride_z,
1452 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001453#if defined(BETA)
1454 uint bias_stride_z,
1455#endif //defined(BETA)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001456 uint dst_stride_z
1457#if defined(REINTERPRET_INPUT_AS_3D)
1458 ,
1459 uint lhs_cross_plane_pad
1460#endif // REINTERPRET_INPUT_AS_3D
1461#if defined(REINTERPRET_OUTPUT_AS_3D)
1462 ,
1463 uint dst_cross_plane_pad
1464#endif // REINTERPRET_OUTPUT_AS_3D
1465 )
1466{
1467 // Block size
1468#define RHS_BLOCK_SIZE ((K0) * (N0))
1469
1470 // RHS offset and step X
1471#if defined(RHS_INTERLEAVE)
1472#define RHS_OFFSET_X (N0)
1473#define RHS_STEP_X ((N0) * (H0))
1474#define RHS_STEP_LOOP (1)
1475#else // defined(RHS_INTERLEAVE)
1476#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1477#define RHS_STEP_X (N0)
1478#define RHS_STEP_LOOP (H0)
1479#endif // defined(RHS_INTERLEAVE)
1480
1481 uint x = get_global_id(0);
1482 uint y = get_global_id(1);
1483 uint z = get_global_id(2);
1484
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001485#if defined(DUMMY_WORK_ITEMS)
1486 if((x * N0 >= N) || (y * M0 >= M))
1487 {
1488 return;
1489 }
1490#endif // defined(DUMMY_WORK_ITEMS)
1491
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001492 // Compute LHS matrix address
1493 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1494
Sheri Zhang1a378102020-04-30 12:59:39 +01001495 // Compute RHS reshaped matrix address
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001496 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1497
1498#if defined(MATRIX_B_DEPTH)
1499 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1500 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1501#else // defined(MATRIX_B_DEPTH)
1502 rhs_offset += z * rhs_stride_z;
1503#endif // defined(MATRIX_B_DEPTH)
1504
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001505 REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0); //uint zin0=0,zin1=0,zin2=0,... zin7=0;
1506 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... zero7=0;
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001507
1508#if defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001509
1510 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001511 CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001512
1513 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1514 // multiply lhs_stride_z by DEPTH_GEMM3D
1515 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1516
1517#else // defined(REINTERPRET_INPUT_AS_3D)
1518
1519 // Add offset for batched GEMM
1520 lhs_offset += z * lhs_stride_z;
1521
1522#endif // defined(REINTERPRET_INPUT_AS_3D)
1523
1524 // Initialize the accumulators
1525 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
1526
1527 int i = 0;
1528 for(; i <= (K - K0); i += K0)
1529 {
1530 // Supported cases (M0, K0):
1531 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1532 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1533 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1534 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1535 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1536 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1537 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1538 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1539 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001540 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001541
1542 LD_RHS_VFMA_M0xN0(0, a, c);
1543 LD_RHS_VFMA_M0xN0(1, a, c);
1544#if K0 > 2
1545 LD_RHS_VFMA_M0xN0(2, a, c);
1546#endif // K0 > 2
1547#if K0 > 3
1548 LD_RHS_VFMA_M0xN0(3, a, c);
1549#endif // K0 > 3
1550#if K0 > 4
1551 LD_RHS_VFMA_M0xN0(4, a, c);
1552 LD_RHS_VFMA_M0xN0(5, a, c);
1553 LD_RHS_VFMA_M0xN0(6, a, c);
1554 LD_RHS_VFMA_M0xN0(7, a, c);
1555#endif // K0 > 4
1556#if K0 > 8
1557 LD_RHS_VFMA_M0xN0(8, a, c);
1558 LD_RHS_VFMA_M0xN0(9, a, c);
1559 LD_RHS_VFMA_M0xN0(A, a, c);
1560 LD_RHS_VFMA_M0xN0(B, a, c);
1561 LD_RHS_VFMA_M0xN0(C, a, c);
1562 LD_RHS_VFMA_M0xN0(D, a, c);
1563 LD_RHS_VFMA_M0xN0(E, a, c);
1564 LD_RHS_VFMA_M0xN0(F, a, c);
1565#endif // K0 > 8
1566
1567 lhs_offset += K0 * sizeof(DATA_TYPE);
1568 rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
1569 }
1570
1571 // Left-over accumulations
1572 for(; i < K; ++i)
1573 {
1574 // Load values from LHS matrix
1575 VEC_DATA_TYPE(DATA_TYPE, 2)
1576 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
1577#if M0 > 1
1578 VEC_DATA_TYPE(DATA_TYPE, 2)
1579 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
1580#endif // M0 > 1
1581#if M0 > 2
1582 VEC_DATA_TYPE(DATA_TYPE, 2)
1583 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
1584#endif // M0 > 2
1585#if M0 > 3
1586 VEC_DATA_TYPE(DATA_TYPE, 2)
1587 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
1588#endif // M0 > 3
1589#if M0 > 4
1590 VEC_DATA_TYPE(DATA_TYPE, 2)
1591 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
1592#endif // M0 > 4
1593#if M0 > 5
1594 VEC_DATA_TYPE(DATA_TYPE, 2)
1595 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
1596#endif // M0 > 5
1597#if M0 > 6
1598 VEC_DATA_TYPE(DATA_TYPE, 2)
1599 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
1600#endif // M0 > 6
1601#if M0 > 7
1602 VEC_DATA_TYPE(DATA_TYPE, 2)
giuros01b3204e72019-04-01 13:50:22 +01001603 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001604#endif // M0 > 7
1605
1606 LD_RHS_VFMA_M0xN0(0, a, c);
1607
1608 lhs_offset += sizeof(DATA_TYPE);
1609 rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
1610 }
1611
1612 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1613
1614 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1615
1616#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001617 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001618 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001619
1620 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1621 // multiply dst_stride_z by DEPTH_GEMM3D
1622 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1623
1624#else // defined(REINTERPRET_OUTPUT_AS_3D)
1625
1626 // Add offset for batched GEMM
1627 dst_addr += z * dst_stride_z;
1628
1629#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1630
1631 // Multiply by the weight of matrix-matrix product and store the result
1632#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001633 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001634#endif // defined(ALPHA)
1635
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001636 // Add beta*bias
1637#if defined(BETA)
1638#if defined(BROADCAST_BIAS)
1639 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1640
1641 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1642
1643#ifndef UNIT_BETA
1644 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1645#endif // UNIT_BIAS
1646
1647 // c = c + bias[broadcasted]
1648 ADD_BLOCK_BROADCAST(M0, c, bias0);
1649
1650#else // defined(BROADCAST_BIAS)
1651 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1652 2) * bias_stride_z;
1653
1654 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1655
1656#ifndef UNIT_BETA
1657 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1658#endif // UNIT_BIAS
1659
1660 // c = c + bias
1661 ADD_BLOCK(M0, c, bias);
1662
1663#endif // defined(BROADCAST_BIAS)
1664#endif // defined(BETA)
1665
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001666#if defined(ACTIVATION_TYPE)
1667 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1668#endif // defined(ACTIVATION_TYPE)
1669
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001670 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001671 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001672
1673#undef RHS_BLOCK_SIZE
1674#undef RHS_OFFSET_X
1675#undef RHS_STEP_X
1676}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001677#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001678
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001679#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001680
/** Dot product of two K0-element vectors accumulated into a scalar: c += sum(a.s<k> * b.s<k>), k in [0, K0).
 *
 * Two variants are provided, selected at compile time:
 *  - MIXED_PRECISION defined:     every product is formed with "*" and then added to c with "+="
 *  - MIXED_PRECISION not defined: every term is accumulated with fma(a.s<k>, b.s<k>, c)
 *
 * @note K0 must be 2, 3, 4, 8 or 16. Any other value stops the build.
 *
 * @param[in]      a First vector (at least K0 components)
 * @param[in]      b Second vector (at least K0 components)
 * @param[in, out] c Scalar accumulator
 */
#if defined(MIXED_PRECISION)
#if K0 == 2
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        c += a.s0 * b.s0;   \
        c += a.s1 * b.s1;   \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        c += a.s0 * b.s0;   \
        c += a.s1 * b.s1;   \
        c += a.s2 * b.s2;   \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        c += a.s0 * b.s0;   \
        c += a.s1 * b.s1;   \
        c += a.s2 * b.s2;   \
        c += a.s3 * b.s3;   \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        c += a.s0 * b.s0;   \
        c += a.s1 * b.s1;   \
        c += a.s2 * b.s2;   \
        c += a.s3 * b.s3;   \
        c += a.s4 * b.s4;   \
        c += a.s5 * b.s5;   \
        c += a.s6 * b.s6;   \
        c += a.s7 * b.s7;   \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c) \
    ({                      \
        c += a.s0 * b.s0;   \
        c += a.s1 * b.s1;   \
        c += a.s2 * b.s2;   \
        c += a.s3 * b.s3;   \
        c += a.s4 * b.s4;   \
        c += a.s5 * b.s5;   \
        c += a.s6 * b.s6;   \
        c += a.s7 * b.s7;   \
        c += a.s8 * b.s8;   \
        c += a.s9 * b.s9;   \
        c += a.sA * b.sA;   \
        c += a.sB * b.sB;   \
        c += a.sC * b.sC;   \
        c += a.sD * b.sD;   \
        c += a.sE * b.sE;   \
        c += a.sF * b.sF;   \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions
#else  // defined(MIXED_PRECISION)
#if K0 == 2
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
        c = fma(a.s4, b.s4, c); \
        c = fma(a.s5, b.s5, c); \
        c = fma(a.s6, b.s6, c); \
        c = fma(a.s7, b.s7, c); \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c)     \
    ({                          \
        c = fma(a.s0, b.s0, c); \
        c = fma(a.s1, b.s1, c); \
        c = fma(a.s2, b.s2, c); \
        c = fma(a.s3, b.s3, c); \
        c = fma(a.s4, b.s4, c); \
        c = fma(a.s5, b.s5, c); \
        c = fma(a.s6, b.s6, c); \
        c = fma(a.s7, b.s7, c); \
        c = fma(a.s8, b.s8, c); \
        c = fma(a.s9, b.s9, c); \
        c = fma(a.sA, b.sA, c); \
        c = fma(a.sB, b.sB, c); \
        c = fma(a.sC, b.sC, c); \
        c = fma(a.sD, b.sD, c); \
        c = fma(a.sE, b.sE, c); \
        c = fma(a.sF, b.sF, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions
#endif // defined(MIXED_PRECISION)

/** Accumulate N0 dot products of K0 elements: c.s<n> += dot(a, b<n>), n in [0, N0).
 *
 * Every dot product is delegated to ARM_DOT_K0.
 *
 * @note N0 must be 2, 3, 4, 8 or 16. Any other value stops the build.
 *
 * @param[in]      a Vector shared by all the N0 dot products
 * @param[in]      b Basename of the N0 vectors (b0, b1, ..., b9, bA, ..., bF)
 * @param[in, out] c Accumulator vector with N0 components (.s0, .s1, ..., .sF)
 */
// A macro with the same name but four parameters is used earlier in this file. Remove any previous
// definition so that the three-parameter version below is not an incompatible redefinition
#undef ARM_DOT_K0XN0
#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions

1854/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1855 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1856 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1857 *
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001858 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1859 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
1860 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001861 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01001862 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001863 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1864 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1865 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001866 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
1867 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
1868 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01001869 * - M0 = 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001870 * - N0 = 2, 3, 4, 8, 16
1871 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001872 * - V0 >= 1
1873 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001874 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001875 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001876 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001877 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001878 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1879 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1880 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1881 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1882 *
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001883 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
1884 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
1885 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1886 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
1887 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1888 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
1889 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1890 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
1891 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1892 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
1893 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1894 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
1895 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1896 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1897 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1898 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1899 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1900 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1901 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1902 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1903 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1904 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1905 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1906 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001907 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1908 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1909 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1910 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1911 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001912 */
1913__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1914 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001915#if defined(BETA)
1916 IMAGE_DECLARATION(bias),
1917#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001918 IMAGE_DECLARATION(dst),
1919 uint lhs_stride_z,
1920 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001921#if defined(BETA)
1922 uint bias_stride_z,
1923#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001924 uint dst_stride_z
1925#if defined(REINTERPRET_OUTPUT_AS_3D)
1926 ,
1927 uint dst_cross_plane_pad
1928#endif // REINTERPRET_OUTPUT_AS_3D
1929 )
1930{
1931 // Block size
1932#define LHS_BLOCK_SIZE ((K0) * (M0))
1933
1934#if defined(LHS_INTERLEAVE)
1935#define LHS_OFFSET_X (K0)
1936#define LHS_STEP_X ((K0) * (V0))
1937#define LHS_STEP_LOOP (1)
1938#else // defined(INTERLEAVE)
1939#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1940#define LHS_STEP_X (K0)
1941#define LHS_STEP_LOOP (V0)
1942#endif // defined(INTERLEAVE)
1943
1944 // Block size
1945#define RHS_BLOCK_SIZE ((K0) * (N0))
1946
1947 // RHS offset and step X
1948#if defined(RHS_INTERLEAVE)
1949#define RHS_OFFSET_X (K0)
1950#define RHS_STEP_X ((K0) * (H0))
1951#define RHS_STEP_LOOP (1)
1952#else // defined(RHS_INTERLEAVE)
1953#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1954#define RHS_STEP_X (K0)
1955#define RHS_STEP_LOOP (H0)
1956#endif // defined(RHS_INTERLEAVE)
1957
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001958#if defined(DUMMY_WORK_ITEMS)
1959 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1960 {
1961 return;
1962 }
1963#endif // defined(DUMMY_WORK_ITEMS)
1964
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001965 // Compute LHS matrix address
1966 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1967 (get_global_id(2) * lhs_stride_z);
1968
1969 // Compute RHS matrix address
1970 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1971
1972#if defined(MATRIX_B_DEPTH)
1973 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1974 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1975#else // defined(MATRIX_B_DEPTH)
1976 rhs_addr += get_global_id(2) * rhs_stride_z;
1977#endif // defined(MATRIX_B_DEPTH)
1978
1979 // Initialize the accumulators
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001980 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001981
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001982 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1983 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001984
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01001985 for(int i = 0; i < K; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001986 {
1987 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001988 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1989 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1990 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1991 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1992 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1993 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1994 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1995 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001996 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001997 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001998
1999 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002000 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002001
2002 // Accumulate
2003 ARM_DOT_K0XN0(a0, b, c0);
2004#if M0 > 1
2005 ARM_DOT_K0XN0(a1, b, c1);
2006#endif // M0 > 1
2007#if M0 > 2
2008 ARM_DOT_K0XN0(a2, b, c2);
2009#endif // M0 > 2
2010#if M0 > 3
2011 ARM_DOT_K0XN0(a3, b, c3);
2012#endif // M0 > 3
2013#if M0 > 4
2014 ARM_DOT_K0XN0(a4, b, c4);
2015#endif // M0 > 4
2016#if M0 > 5
2017 ARM_DOT_K0XN0(a5, b, c5);
2018#endif // M0 > 5
2019#if M0 > 6
2020 ARM_DOT_K0XN0(a6, b, c6);
2021#endif // M0 > 6
2022#if M0 > 7
2023 ARM_DOT_K0XN0(a7, b, c7);
2024#endif // M0 > 7
2025
2026 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2027 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
2028 }
2029
2030 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2031
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002032 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002033
2034#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002035
2036 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01002037 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002038 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2039 // multiply dst_stride_z by DEPTH_GEMM3D
2040 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2041
2042#else // defined(REINTERPRET_OUTPUT_AS_3D)
2043
2044 // Add offset for batched GEMM
2045 dst_addr += get_global_id(2) * dst_stride_z;
2046
2047#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2048
2049 // Multiply by the weight of matrix-matrix product and store the result
2050#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01002051 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002052#endif // defined(ALPHA)
2053
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002054 // Add beta*bias
2055#if defined(BETA)
2056#if defined(BROADCAST_BIAS)
2057 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2058
2059 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2060
2061#ifndef UNIT_BETA
2062 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2063#endif // UNIT_BIAS
2064
2065 // c = c + bias[broadcasted]
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002066#if defined(MIXED_PRECISION)
2067 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2068 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2069#else // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002070 ADD_BLOCK_BROADCAST(M0, c, bias0);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002071#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002072
2073#else // defined(BROADCAST_BIAS)
2074 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2075 2) * bias_stride_z;
2076
2077 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2078
2079#ifndef UNIT_BETA
2080 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2081#endif // UNIT_BIAS
2082
2083 // c = c + bias
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002084#if defined(MIXED_PRECISION)
2085 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2086 ADD_BLOCK(M0, c, bias_hp);
2087#else // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002088 ADD_BLOCK(M0, c, bias);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002089#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002090
2091#endif // defined(BROADCAST_BIAS)
2092#endif // defined(BETA)
2093
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002094#if defined(ACTIVATION_TYPE)
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002095#if defined(MIXED_PRECISION)
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002096 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002097#else // defined(MIXED_PRECISION)
2098 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2099#endif // defined(MIXED_PRECISION)
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002100#endif // defined(ACTIVATION_TYPE)
2101
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002102 // Store output block
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002103#if defined(MIXED_PRECISION)
2104 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2105#else // defined(MIXED_PRECISION)
Usama Arif0681e3b2019-04-25 14:28:07 +01002106 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002107#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002108
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002109#undef LHS_BLOCK_SIZE
2110#undef LHS_OFFSET_X
2111#undef LHS_STEP_X
2112#undef RHS_BLOCK_SIZE
2113#undef RHS_OFFSET_X
2114#undef RHS_STEP_X
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002115#undef LHS_STEP_LOOP
2116#undef RHS_STEP_LOOP
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002117}
giuros01b3204e72019-04-01 13:50:22 +01002118
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002119#if defined(OPENCL_IMAGE_SUPPORT)
2120/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
2121 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
2122 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
2123 *
2124 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
2125 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
2126 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
2127 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
2128 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2129 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
2130 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2131 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2132 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2135 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2136 * - M0 = 2, 3, 4, 5, 6, 7, 8
2137 * - N0 = 4, 8, 16
2138 * - K0 = 4, 8, 16
2139 * - V0 >= 1
2140 * - H0 >= 1
2141 *
2142 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2143 * The activation function is performed after the bias addition
2144 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2145 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2146 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2147 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2148 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2149 *
2150 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F32
2151 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2152 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2153 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2154 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2155 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2156 * @param[in] rhs_img The RHS reshaped matrix as OpenCL image object. Supported data type: same as @p lhs_ptr
2157 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2158 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2159 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2160 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2161 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2162 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2163 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2164 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2165 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2166 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2167 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2168 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2169 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2170 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2171 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2172 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2173 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2174 */
2175__kernel void gemm_mm_reshaped_lhs_nt_rhs_t_texture(IMAGE_DECLARATION(lhs),
2176 __read_only image2d_t rhs_img,
2177#if defined(BETA)
2178 IMAGE_DECLARATION(bias),
2179#endif // defined(BETA)
2180 IMAGE_DECLARATION(dst),
2181 uint lhs_stride_z,
2182 uint rhs_stride_z,
2183#if defined(BETA)
2184 uint bias_stride_z,
2185#endif //defined(BETA)
2186 uint dst_stride_z
2187#if defined(REINTERPRET_OUTPUT_AS_3D)
2188 ,
2189 uint dst_cross_plane_pad
2190#endif // REINTERPRET_OUTPUT_AS_3D
2191 )
2192{
2193 // Pixel unit
2194#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)
2195
2196 // Block size
2197#define LHS_BLOCK_SIZE ((K0) * (M0))
2198
2199#if defined(LHS_INTERLEAVE)
2200#define LHS_OFFSET_X (K0)
2201#define LHS_STEP_X ((K0) * (V0))
2202#define LHS_STEP_LOOP (1)
2203#else // defined(INTERLEAVE)
2204#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2205#define LHS_STEP_X (K0)
2206#define LHS_STEP_LOOP (V0)
2207#endif // defined(INTERLEAVE)
2208
2209 // Block size
2210#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))
2211
2212 // RHS offset and step X
2213#if defined(RHS_INTERLEAVE)
2214#define RHS_OFFSET_X (PIXEL_UNIT)
2215#define RHS_STEP_X (PIXEL_UNIT * (H0))
2216#define RHS_STEP_LOOP (1)
2217#else // defined(RHS_INTERLEAVE)
2218#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2219#define RHS_STEP_X PIXEL_UNIT
2220#define RHS_STEP_LOOP (H0)
2221#endif // defined(RHS_INTERLEAVE)
2222
2223#if defined(DUMMY_WORK_ITEMS)
2224 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
2225 {
2226 return;
2227 }
2228#endif // defined(DUMMY_WORK_ITEMS)
2229
2230 // Compute LHS matrix address
2231 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
2232 (get_global_id(2) * lhs_stride_z);
2233
2234#if defined(MATRIX_B_DEPTH)
2235 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2236 const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
2237#else // defined(MATRIX_B_DEPTH)
2238 const uint z_rhs = get_global_id(2);
2239#endif // defined(MATRIX_B_DEPTH)
2240
2241 // Compute RHS matrix coordinates
2242 uint x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
2243 const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;
2244
2245 // Initialize the accumulators
2246 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
2247
2248 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
2249 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
2250
2251 for(int i = 0; i < K; i += K0)
2252 {
2253 // Load values from LHS matrix
2254 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
2255
2256 // Load values from RHS matrix stored in a cl_image
2257 REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
2258 LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);
2259
2260 // Accumulate
2261 ARM_DOT_K0XN0(a0, b, c0);
2262#if M0 > 1
2263 ARM_DOT_K0XN0(a1, b, c1);
2264#endif // M0 > 1
2265#if M0 > 2
2266 ARM_DOT_K0XN0(a2, b, c2);
2267#endif // M0 > 2
2268#if M0 > 3
2269 ARM_DOT_K0XN0(a3, b, c3);
2270#endif // M0 > 3
2271#if M0 > 4
2272 ARM_DOT_K0XN0(a4, b, c4);
2273#endif // M0 > 4
2274#if M0 > 5
2275 ARM_DOT_K0XN0(a5, b, c5);
2276#endif // M0 > 5
2277#if M0 > 6
2278 ARM_DOT_K0XN0(a6, b, c6);
2279#endif // M0 > 6
2280#if M0 > 7
2281 ARM_DOT_K0XN0(a7, b, c7);
2282#endif // M0 > 7
2283
2284 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2285
2286 x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
2287 }
2288
2289 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2290
2291 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2292
2293#if defined(REINTERPRET_OUTPUT_AS_3D)
2294
2295 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2296 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2297 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2298 // multiply dst_stride_z by DEPTH_GEMM3D
2299 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2300
2301#else // defined(REINTERPRET_OUTPUT_AS_3D)
2302
2303 // Add offset for batched GEMM
2304 dst_addr += get_global_id(2) * dst_stride_z;
2305
2306#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2307
2308 // Multiply by the weight of matrix-matrix product and store the result
2309#if defined(ALPHA)
2310 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2311#endif // defined(ALPHA)
2312
2313 // Add beta*bias
2314#if defined(BETA)
2315#if defined(BROADCAST_BIAS)
2316 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2317
2318 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2319
2320#ifndef UNIT_BETA
2321 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2322#endif // UNIT_BIAS
2323
2324 // c = c + bias[broadcasted]
2325#if defined(MIXED_PRECISION)
2326 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2327 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2328#else // defined(MIXED_PRECISION)
2329 ADD_BLOCK_BROADCAST(M0, c, bias0);
2330#endif // defined(MIXED_PRECISION)
2331
2332#else // defined(BROADCAST_BIAS)
2333 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2334 2) * bias_stride_z;
2335
2336 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2337
2338#ifndef UNIT_BETA
2339 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2340#endif // UNIT_BIAS
2341
2342 // c = c + bias
2343#if defined(MIXED_PRECISION)
2344 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2345 ADD_BLOCK(M0, c, bias_hp);
2346#else // defined(MIXED_PRECISION)
2347 ADD_BLOCK(M0, c, bias);
2348#endif // defined(MIXED_PRECISION)
2349
2350#endif // defined(BROADCAST_BIAS)
2351#endif // defined(BETA)
2352
2353#if defined(ACTIVATION_TYPE)
2354#if defined(MIXED_PRECISION)
2355 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
2356#else // defined(MIXED_PRECISION)
2357 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2358#endif // defined(MIXED_PRECISION)
2359#endif // defined(ACTIVATION_TYPE)
2360
2361 // Store output block
2362#if defined(MIXED_PRECISION)
2363 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2364#else // defined(MIXED_PRECISION)
2365 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2366#endif // defined(MIXED_PRECISION)
2367
2368#undef LHS_BLOCK_SIZE
2369#undef LHS_OFFSET_X
2370#undef LHS_STEP_X
2371#undef RHS_BLOCK_SIZE
2372#undef RHS_OFFSET_X
2373#undef RHS_STEP_X
2374#undef PIXEL_UNIT
2375#undef LHS_STEP_LOOP
2376#undef RHS_STEP_LOOP
2377}
2378#endif // defined(OPENCL_IMAGE_SUPPORT)
2379
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002380#if defined(LHS_TRANSPOSE)
2381
// Shorthand for the vector type made of SIZE lanes of TYPE
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)

// ARM_VFMA(N0, a, b, c): multiply-accumulate c += a * b on N0-wide vectors.
//  - With MIXED_PRECISION, a and b are converted to VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0) first.
//  - When GPU_ARCH == GPU_ARCH_MIDGARD a plain multiply-add expression is emitted, otherwise fma().
// Every variant already ends with ';', so callers get an extra (harmless) empty statement.
#if defined(MIXED_PRECISION)

#if(GPU_ARCH == GPU_ARCH_MIDGARD)
#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
#else // GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

#else // defined(MIXED_PRECISION)

#if(GPU_ARCH == GPU_ARCH_MIDGARD)
#define ARM_VFMA(N0, a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

#endif // defined(MIXED_PRECISION)

// Column-vector (a, transposed) by row-vector (b, not transposed) products for a single k.
// Lane i of a is broadcast to an N0-wide vector and multiplied with b into the output row C##i.
// Variants exist for 1, 2, 3, 4 and 8 output rows; the 1-row variant takes a scalar a.
#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)          \
    ({                                                  \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0));  \
    })
#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
    })
#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
    })
#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
    })
#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
    })

// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
// a is the column-vector (transposed)
// b is the row-vector (not transposed)
// C is the output matrix
// Lower case is a vector (a, b)
// Upper case is a matrix (C)
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)

// Matrix (transposed) by matrix (not transposed) products for K0 = 1, 2, 3, 4, 8 and 16.
// LHS, RHS and ACC are the base names of the operand and accumulator variables: row k of the
// operands is LHS##k / RHS##k, with k written as an upper-case hexadecimal digit (0..9, A..F).
// The parameters are deliberately NOT called A, B and C: inside the x16 variant the suffixes
// A, B and C would otherwise be replaced by the macro arguments (A##A would paste to "aa"
// instead of "aA").
#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, ACC)                \
    ({                                                                  \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##0), (RHS##0), ACC);    \
    })
#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, ACC)                \
    ({                                                                  \
        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, LHS, RHS, ACC);               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##1), (RHS##1), ACC);    \
    })
#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, ACC)                \
    ({                                                                  \
        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, LHS, RHS, ACC);               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##2), (RHS##2), ACC);    \
    })
#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, ACC)                \
    ({                                                                  \
        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, LHS, RHS, ACC);               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##3), (RHS##3), ACC);    \
    })
#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, ACC)                \
    ({                                                                  \
        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, LHS, RHS, ACC);               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##4), (RHS##4), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##5), (RHS##5), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##6), (RHS##6), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##7), (RHS##7), ACC);    \
    })
// Rows 8..F go straight to the vector-by-vector macro: routing an already pasted and
// parenthesised row through ARM_MM_T_NT_M0xN0x1 would try to paste ")" with "0".
#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, LHS, RHS, ACC)               \
    ({                                                                  \
        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, ACC);               \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##8), (RHS##8), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##9), (RHS##9), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##A), (RHS##A), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##B), (RHS##B), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##C), (RHS##C), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##D), (RHS##D), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##E), (RHS##E), ACC);    \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##F), (RHS##F), ACC);    \
    })

// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
// The dimensions for this matrix multiplications are defined through M0, N0 and K0
// The dimensions supported are:
// M0: 1, 2, 3, 4, 8
// N0: 1, 2, 3, 4, 8, 16
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-matrix macro K0 times
// A, B and C are matrices
#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
    (M0, N0, TYPE, A, B, C)
2489
2490/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2491 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2492 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2493 *
2494 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2495 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002496 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002497 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2498 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2499 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
2500 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must passed at compile time.
2501 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must passed at compile time.
2502 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2503 * - M0 = 2, 3, 4, 8
2504 * - N0 = 2, 3, 4, 8, 16
2505 * - K0 = 2, 3, 4, 8, 16
2506 * - V0 >= 1
2507 * - H0 >= 1
2508 *
2509 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2510 * The activation function is performed after the bias addition
2511 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2512 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2513 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2514 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2515 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2516 *
2517 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2518 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
2519 * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2520 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
2521 * @param[in] lhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2522 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2523 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2524 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
2525 * @param[in] rhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2526 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
2527 * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2528 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2529 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2530 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2531 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2532 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2533 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2534 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2535 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2536 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2537 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2538 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2539 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2540 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002541 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2542 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2543 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2544 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2545 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2546 */
2547__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
2548 IMAGE_DECLARATION(rhs),
2549#if defined(BETA)
2550 IMAGE_DECLARATION(bias),
2551#endif // defined(BETA)
2552 IMAGE_DECLARATION(dst),
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002553 uint lhs_stride_z,
2554 uint rhs_stride_z,
2555#if defined(BETA)
2556 uint bias_stride_z,
2557#endif //defined(BETA)
2558 uint dst_stride_z
2559#if defined(REINTERPRET_OUTPUT_AS_3D)
2560 ,
2561 uint dst_cross_plane_pad
2562#endif // REINTERPRET_OUTPUT_AS_3D
2563 )
2564{
2565 // Block size
2566#define LHS_BLOCK_SIZE ((K0) * (M0))
2567
2568#if defined(LHS_INTERLEAVE)
2569#define LHS_OFFSET_X (M0)
2570#define LHS_STEP_X ((M0) * (V0))
2571#define LHS_STEP_LOOP (1)
2572#else // defined(INTERLEAVE)
2573#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2574#define LHS_STEP_X (M0)
2575#define LHS_STEP_LOOP (V0)
2576#endif // defined(INTERLEAVE)
2577
2578 // Block size
2579#define RHS_BLOCK_SIZE ((K0) * (N0))
2580
2581 // RHS offset and step X
2582#if defined(RHS_INTERLEAVE)
2583#define RHS_OFFSET_X (N0)
2584#define RHS_STEP_X ((N0) * (H0))
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002585#else // defined(RHS_INTERLEAVE)
2586#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2587#define RHS_STEP_X (N0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002588#endif // defined(RHS_INTERLEAVE)
2589
2590 const uint x = get_global_id(0);
2591 const uint y = get_global_id(1);
2592 const uint z = get_global_id(2);
2593
2594#if defined(DUMMY_WORK_ITEMS)
2595 if((x * N0 >= N) || (y * M0 >= M))
2596 {
2597 return;
2598 }
2599#endif // defined(DUMMY_WORK_ITEMS)
2600
2601 // Compute LHS matrix address
2602 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2603
2604 // Compute RHS matrix address
2605 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2606
2607#if defined(MATRIX_B_DEPTH)
2608 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2609 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2610#else // defined(MATRIX_B_DEPTH)
2611 rhs_addr += z * rhs_stride_z;
2612#endif // defined(MATRIX_B_DEPTH)
2613
2614 // Initialize the accumulators
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002615 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002616
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002617 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2618
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002619 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2620 __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2621
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002622 for(int i = 0; i < K; i += K0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002623 {
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002624 VEC_DATA_TYPE(DATA_TYPE, M0)
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002625 a0;
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002626 VEC_DATA_TYPE(DATA_TYPE, N0)
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002627 b0;
2628
2629 a0 = VLOAD(M0)(0, lhs);
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002630 b0 = VLOAD(N0)(0, rhs);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002631
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002632 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002633
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002634 lhs += LHS_STEP_X;
2635 rhs += RHS_STEP_X;
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002636
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002637#if K0 > 1
2638 a0 = VLOAD(M0)(0, lhs);
2639 b0 = VLOAD(N0)(0, rhs);
2640
2641 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2642
2643 lhs += LHS_STEP_X;
2644 rhs += RHS_STEP_X;
2645#endif // K0 > 1
2646
2647#if K0 > 2
2648 a0 = VLOAD(M0)(0, lhs);
2649 b0 = VLOAD(N0)(0, rhs);
2650
2651 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2652
2653 lhs += LHS_STEP_X;
2654 rhs += RHS_STEP_X;
2655#endif // K0 > 2
2656
2657#if K0 > 3
2658 a0 = VLOAD(M0)(0, lhs);
2659 b0 = VLOAD(N0)(0, rhs);
2660
2661 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2662
2663 lhs += LHS_STEP_X;
2664 rhs += RHS_STEP_X;
2665#endif // K0 > 3
2666
2667#if K0 > 4
2668 a0 = VLOAD(M0)(0, lhs);
2669 b0 = VLOAD(N0)(0, rhs);
2670
2671 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2672
2673 lhs += LHS_STEP_X;
2674 rhs += RHS_STEP_X;
2675
2676 a0 = VLOAD(M0)(0, lhs);
2677 b0 = VLOAD(N0)(0, rhs);
2678
2679 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2680
2681 lhs += LHS_STEP_X;
2682 rhs += RHS_STEP_X;
2683
2684 a0 = VLOAD(M0)(0, lhs);
2685 b0 = VLOAD(N0)(0, rhs);
2686
2687 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2688
2689 lhs += LHS_STEP_X;
2690 rhs += RHS_STEP_X;
2691
2692 a0 = VLOAD(M0)(0, lhs);
2693 b0 = VLOAD(N0)(0, rhs);
2694
2695 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2696
2697 lhs += LHS_STEP_X;
2698 rhs += RHS_STEP_X;
2699#endif // K0 > 4
2700
2701#if K0 > 8
2702 a0 = VLOAD(M0)(0, lhs);
2703 b0 = VLOAD(N0)(0, rhs);
2704
2705 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2706
2707 lhs += LHS_STEP_X;
2708 rhs += RHS_STEP_X;
2709
2710 a0 = VLOAD(M0)(0, lhs);
2711 b0 = VLOAD(N0)(0, rhs);
2712
2713 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2714
2715 lhs += LHS_STEP_X;
2716 rhs += RHS_STEP_X;
2717
2718 a0 = VLOAD(M0)(0, lhs);
2719 b0 = VLOAD(N0)(0, rhs);
2720
2721 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2722
2723 lhs += LHS_STEP_X;
2724 rhs += RHS_STEP_X;
2725
2726 a0 = VLOAD(M0)(0, lhs);
2727 b0 = VLOAD(N0)(0, rhs);
2728
2729 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2730
2731 lhs += LHS_STEP_X;
2732 rhs += RHS_STEP_X;
2733
2734 a0 = VLOAD(M0)(0, lhs);
2735 b0 = VLOAD(N0)(0, rhs);
2736
2737 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2738
2739 lhs += LHS_STEP_X;
2740 rhs += RHS_STEP_X;
2741
2742 a0 = VLOAD(M0)(0, lhs);
2743 b0 = VLOAD(N0)(0, rhs);
2744
2745 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2746
2747 lhs += LHS_STEP_X;
2748 rhs += RHS_STEP_X;
2749
2750 a0 = VLOAD(M0)(0, lhs);
2751 b0 = VLOAD(N0)(0, rhs);
2752
2753 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2754
2755 lhs += LHS_STEP_X;
2756 rhs += RHS_STEP_X;
2757
2758 a0 = VLOAD(M0)(0, lhs);
2759 b0 = VLOAD(N0)(0, rhs);
2760
2761 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2762
2763 lhs += LHS_STEP_X;
2764 rhs += RHS_STEP_X;
2765#endif // K0 > 8
2766
2767#ifndef LHS_INTERLEAVE
2768 lhs += (M0 * K0 * (V0 - 1));
2769#endif // LHS_INTERLEAVE
2770
2771#ifndef RHS_INTERLEAVE
2772 rhs += (N0 * K0 * (H0 - 1));
2773#endif // RHS_INTERLEAVE
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002774 }
2775
2776 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2777
2778 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2779
2780#if defined(REINTERPRET_OUTPUT_AS_3D)
2781
2782 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2783 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2784 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2785 // multiply dst_stride_z by DEPTH_GEMM3D
2786 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2787
2788#else // defined(REINTERPRET_OUTPUT_AS_3D)
2789
2790 // Add offset for batched GEMM
2791 dst_addr += z * dst_stride_z;
2792
2793#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2794
2795 // Multiply by the weight of matrix-matrix product and store the result
2796#if defined(ALPHA)
2797 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2798#endif // defined(ALPHA)
2799
2800 // Add beta*bias
2801#if defined(BETA)
2802#if defined(BROADCAST_BIAS)
2803 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
2804
2805 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2806
2807#ifndef UNIT_BETA
2808 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2809#endif // UNIT_BIAS
2810
2811 // c = c + bias[broadcasted]
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002812#if defined(MIXED_PRECISION)
2813 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2814 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2815#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002816 ADD_BLOCK_BROADCAST(M0, c, bias0);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002817#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002818
2819#else // defined(BROADCAST_BIAS)
2820 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
2821
2822 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2823
2824#ifndef UNIT_BETA
2825 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2826#endif // UNIT_BIAS
2827
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002828#if defined(MIXED_PRECISION)
2829 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2830 ADD_BLOCK(M0, c, bias_hp);
2831#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002832 ADD_BLOCK(M0, c, bias);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002833#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002834
2835#endif // defined(BROADCAST_BIAS)
2836#endif // defined(BETA)
2837
2838#if defined(ACTIVATION_TYPE)
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002839#if defined(MIXED_PRECISION)
2840 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
2841#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002842 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002843#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002844#endif // defined(ACTIVATION_TYPE)
2845
2846 // Store output block
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002847#if defined(MIXED_PRECISION)
2848 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2849#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002850 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002851#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002852
2853#undef LHS_BLOCK_SIZE
2854#undef LHS_OFFSET_X
2855#undef LHS_STEP_X
2856#undef RHS_BLOCK_SIZE
2857#undef RHS_OFFSET_X
2858#undef RHS_STEP_X
2859}
2860
#if defined(OPENCL_IMAGE_SUPPORT)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
 *
 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
 * @note The height of the RHS matrix should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 2, 3, 4, 8
 *  - N0 = 4, 8, 16
 *  - K0 = 4, 8, 16
 *  - V0 >= 1
 *  - H0 >= 1
 *
 * @note The accumulators are declared with DATA_TYPE_ACCUMULATOR, which must therefore be defined. If -DMIXED_PRECISION is passed, the bias is converted to
 *       DATA_TYPE_ACCUMULATOR before the addition and the result is converted to DATA_TYPE when it is stored.
 * @note If -DMATRIX_B_DEPTH is passed, the RHS batch index is computed as (z % MATRIX_B_DEPTH) instead of z.
 * @note If -DALPHA is passed, the accumulators are scaled by ALPHA. If -DBETA is passed, the bias tensor is added (scaled by BETA unless -DUNIT_BETA is passed);
 *       with -DBROADCAST_BIAS a single bias row is added to every output row.
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F32
 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
 * @param[in]  rhs_img                            The RHS reshaped matrix as cl_image 2d. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes). Not referenced by the kernel body: the RHS batch offset is derived from RHS_HEIGHT
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_lhs_t_rhs_nt_texture(IMAGE_DECLARATION(lhs),
                                                    __read_only image2d_t rhs_img,
#if defined(BETA)
                                                    IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                    IMAGE_DECLARATION(dst),
                                                    uint lhs_stride_z,
                                                    uint rhs_stride_z,
#if defined(BETA)
                                                    uint bias_stride_z,
#endif //defined(BETA)
                                                    uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                    ,
                                                    uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                   )
{
    // Pixel unit
#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)

    // LHS block size
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS offset and step X
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (M0)
#define LHS_STEP_X ((M0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (M0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // RHS block size (in pixel units)
#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (PIXEL_UNIT)
#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (PIXEL_UNIT)
#endif // defined(RHS_INTERLEAVE)

    // Step number k (0-based) along K within one loop iteration: load M0 LHS values from the
    // buffer and the RHS values from the image at x coordinate (x_rhs + k * RHS_STEP_X), feed
    // them to the K0 = 1 variant of ARM_MM_T_NT together with the accumulators c, then advance
    // the LHS pointer by one step.
    // It relies on a0, b0, c, lhs, x_rhs and y_rhs being declared at the point of use.
#define GEMM_MM_T_NT_TEXTURE_STEP(k)                                                                 \
    ({                                                                                               \
        a0 = VLOAD(M0)(0, lhs);                                                                      \
        b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + (k) * RHS_STEP_X), (y_rhs));      \
        ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);                                                  \
        lhs += LHS_STEP_X;                                                                           \
    })

    const uint x = get_global_id(0);
    const uint y = get_global_id(1);
    const uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items dispatched outside the MxN output have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    const uint z_rhs = (z % MATRIX_B_DEPTH);
#else  // defined(MATRIX_B_DEPTH)
    const uint z_rhs = z;
#endif // defined(MATRIX_B_DEPTH)

    // Compute RHS matrix coordinates
    uint       x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
    const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);

    // Zero offsets handed to LOAD_BLOCK when the bias is loaded
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

    __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);

    for(int i = 0; i < K; i += K0)
    {
        VEC_DATA_TYPE(DATA_TYPE, M0) a0;
        VEC_DATA_TYPE(DATA_TYPE, N0) b0;

        // K0 steps per iteration, selected at compile time
        GEMM_MM_T_NT_TEXTURE_STEP(0);

#if K0 > 1
        GEMM_MM_T_NT_TEXTURE_STEP(1);
#endif // K0 > 1

#if K0 > 2
        GEMM_MM_T_NT_TEXTURE_STEP(2);
#endif // K0 > 2

#if K0 > 3
        GEMM_MM_T_NT_TEXTURE_STEP(3);
#endif // K0 > 3

#if K0 > 4
        GEMM_MM_T_NT_TEXTURE_STEP(4);
        GEMM_MM_T_NT_TEXTURE_STEP(5);
        GEMM_MM_T_NT_TEXTURE_STEP(6);
        GEMM_MM_T_NT_TEXTURE_STEP(7);
#endif // K0 > 4

#if K0 > 8
        GEMM_MM_T_NT_TEXTURE_STEP(8);
        GEMM_MM_T_NT_TEXTURE_STEP(9);
        GEMM_MM_T_NT_TEXTURE_STEP(10);
        GEMM_MM_T_NT_TEXTURE_STEP(11);
        GEMM_MM_T_NT_TEXTURE_STEP(12);
        GEMM_MM_T_NT_TEXTURE_STEP(13);
        GEMM_MM_T_NT_TEXTURE_STEP(14);
        GEMM_MM_T_NT_TEXTURE_STEP(15);
#endif // K0 > 8

#ifndef LHS_INTERLEAVE
        lhs += (M0 * K0 * (V0 - 1));
#endif // LHS_INTERLEAVE

        // The image x coordinate is only advanced here, once per loop iteration
        x_rhs += K0 * RHS_STEP_X;
#ifndef RHS_INTERLEAVE
        x_rhs += (PIXEL_UNIT * K0 * (H0 - 1));
#endif // RHS_INTERLEAVE
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK_BROADCAST(M0, c, bias0);
#endif // defined(MIXED_PRECISION)

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK(M0, c, bias_hp);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK(M0, c, bias);
#endif // defined(MIXED_PRECISION)

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
#if defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
#else  // defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(MIXED_PRECISION)
#endif // defined(ACTIVATION_TYPE)

    // Store output block
#if defined(MIXED_PRECISION)
    CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#else  // defined(MIXED_PRECISION)
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#endif // defined(MIXED_PRECISION)

    // Drop every kernel-local macro so that nothing leaks into the rest of the file
#undef GEMM_MM_T_NT_TEXTURE_STEP
#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef PIXEL_UNIT
#undef LHS_STEP_LOOP
#undef RHS_STEP_LOOP
}
#endif // defined(OPENCL_IMAGE_SUPPORT)
3221
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01003222#endif // defined(LHS_TRANSPOSE)
3223
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00003224#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
3225
giuros01b3204e72019-04-01 13:50:22 +01003226#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
3227
// Fused multiply-add with write-back: c = fma(a, b, c).
// c is both read and assigned, so it must be a modifiable lvalue
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
3232
3233#if M0 == 1
3234#define RHS_VFMA_M0xN0(i, a, b, c) \
3235 ({ \
3236 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3237 })
3238#elif M0 == 2 // M0 == 2
3239#define RHS_VFMA_M0xN0(i, a, b, c) \
3240 ({ \
3241 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3242 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3243 })
3244#elif M0 == 3 // M0 == 3
3245#define RHS_VFMA_M0xN0(i, a, b, c) \
3246 ({ \
3247 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3248 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3249 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3250 })
3251#elif M0 == 4 // M0 == 4
3252#define RHS_VFMA_M0xN0(i, a, b, c) \
3253 ({ \
3254 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3255 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3256 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3257 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3258 })
3259#elif M0 == 5 // M0 == 5
3260#define RHS_VFMA_M0xN0(i, a, b, c) \
3261 ({ \
3262 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3263 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3264 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3265 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3266 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3267 })
3268#elif M0 == 6 // M0 == 6
3269#define RHS_VFMA_M0xN0(i, a, b, c) \
3270 ({ \
3271 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3272 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3273 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3274 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3275 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3276 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3277 })
3278#elif M0 == 7 // M0 == 7
3279#define RHS_VFMA_M0xN0(i, a, b, c) \
3280 ({ \
3281 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3282 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3283 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3284 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3285 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3286 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3287 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3288 })
3289#elif M0 == 8 // M0 == 8
3290#define RHS_VFMA_M0xN0(i, a, b, c) \
3291 ({ \
3292 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3293 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3294 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3295 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3296 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3297 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3298 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3299 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
3300 })
3301#else // M0 not supported
3302#error "M0 not supported"
3303#endif // M0 not supported
3304
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS matrix is NOT reshaped
 *
 * Each work-item accumulates an M0xN0 block of the destination, walking the K dimension in steps of K0
 * and finishing the remainder (K % K0) one element at a time.
 *
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
                             IMAGE_DECLARATION(rhs),
#if defined(BETA)
                             IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                             IMAGE_DECLARATION(dst),
                             uint lhs_stride_z,
                             uint rhs_stride_z,
#if defined(BETA)
                             uint bias_stride_z,
#endif //defined(BETA)
                             uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                             ,
                             uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                             ,
                             uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                            )
{
    // Block coordinates of this work-item: x selects the N0 columns, y the M0 rows, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose block starts outside the MxN output have nothing to compute
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Per-row extra offsets: zlhs<n> for the LHS rows, zero<n> (always 0) for tensors without cross-plane padding
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;

    // Main loop: consume K0 elements of the K dimension per iteration
    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);

        // Load values from RHS matrix
        LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);

        // One accumulation per loaded RHS row: c<m> += a<m>.s<k> * b<k>
        RHS_VFMA_M0xN0(0, a, b0, c);
        RHS_VFMA_M0xN0(1, a, b1, c);
#if K0 > 2
        RHS_VFMA_M0xN0(2, a, b2, c);
#endif // K0 > 2
#if K0 > 3
        RHS_VFMA_M0xN0(3, a, b3, c);
#endif // K0 > 3
#if K0 > 4
        RHS_VFMA_M0xN0(4, a, b4, c);
        RHS_VFMA_M0xN0(5, a, b5, c);
        RHS_VFMA_M0xN0(6, a, b6, c);
        RHS_VFMA_M0xN0(7, a, b7, c);
#endif // K0 > 4
#if K0 > 8
        RHS_VFMA_M0xN0(8, a, b8, c);
        RHS_VFMA_M0xN0(9, a, b9, c);
        RHS_VFMA_M0xN0(A, a, bA, c);
        RHS_VFMA_M0xN0(B, a, bB, c);
        RHS_VFMA_M0xN0(C, a, bC, c);
        RHS_VFMA_M0xN0(D, a, bD, c);
        RHS_VFMA_M0xN0(E, a, bE, c);
        RHS_VFMA_M0xN0(F, a, bF, c);
#endif // K0 > 8

        // Advance K0 columns in the LHS and K0 rows in the RHS
        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * rhs_stride_y;
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix.
        // A single element is read per row; it is held in a 2-element vector so that
        // RHS_VFMA_M0xN0 can select it through the .s0 component.
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
#endif // M0 > 7

        // Load one row of N0 values from RHS matrix
        VEC_DATA_TYPE(DATA_TYPE, N0)
        b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
        RHS_VFMA_M0xN0(0, a, b, c);

        // Advance one column in the LHS and one row in the RHS
        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += rhs_stride_y;
    }

    // Address of the first element of the output block
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight (alpha) of the matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to every row of the block
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    // The bias has one row per output row: load the matching M0xN0 block
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE))
                                + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
}
3588#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
3589
Gian Marco36a0a462018-01-12 10:21:40 +00003590#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * Each work-item accumulates and stores one 4x4 block of F32 results.
 *
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Row of the reshaped matrices addressed by this work-item, and the batch index
    int blk_x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int blk_y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int batch = get_global_id(2);

    // Position (in elements) of this work-item's 4-wide slice inside the reshaped row
    const int sub_offset_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int sub_offset_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;

    // Byte offsets of the rows of matrix A (src0) and matrix B (src1)
    int a_bytes = batch * src0_stride_z + blk_y * src0_stride_y + src0_offset_first_element_in_bytes;
    int b_bytes = blk_x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    b_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    b_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *a_ptr = (__global float *)(src0_ptr + a_bytes);
    __global float *b_ptr = (__global float *)(src1_ptr + b_bytes);

    // End of the row of matrix B; taken before the sub-block offset is applied
    __global float *b_end = b_ptr + COLS_B;

    a_ptr += sub_offset_a;
    b_ptr += sub_offset_b;

    // Accumulators: one float4 per row of the 4x4 output block
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    while(b_ptr <= (b_end - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        // First step: 4 values of A (interleaved) against 4 values of B (transposed)
        float4 a_vec = vload4(0, a_ptr);
        float4 b_vec = vload4(0, b_ptr);

        c0 += (float4)a_vec.s0 * b_vec;
        c1 += (float4)a_vec.s1 * b_vec;
        c2 += (float4)a_vec.s2 * b_vec;
        c3 += (float4)a_vec.s3 * b_vec;

        // Second step: the following group of values in each matrix
        a_vec = vload4(0, a_ptr + 4 * MULT_INTERLEAVE4X4_HEIGHT);
        b_vec = vload4(0, b_ptr + 4 * MULT_TRANSPOSE1XW_WIDTH);

        c0 += (float4)a_vec.s0 * b_vec;
        c1 += (float4)a_vec.s1 * b_vec;
        c2 += (float4)a_vec.s2 * b_vec;
        c3 += (float4)a_vec.s3 * b_vec;

        a_ptr += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Left-over: one accumulation step per iteration until the end of the row of B
    while(b_ptr < b_end)
    {
        float4 a_vec = vload4(0, a_ptr);
        float4 b_vec = vload4(0, b_ptr);

        c0 += (float4)a_vec.s0 * b_vec;
        c1 += (float4)a_vec.s1 * b_vec;
        c2 += (float4)a_vec.s2 * b_vec;
        c3 += (float4)a_vec.s3 * b_vec;

        a_ptr += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        b_ptr += 4 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Destination address of the block handled by this work-item
    Image dst_img = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *out_addr = offset(&dst_img, 0, 0);

    // Extra byte offset of each of the 4 output rows (non-zero only for a 3D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    out_addr += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    out_addr += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight (alpha) of the matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to each of the 4 rows
    __global uchar *bias_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, bias_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // The bias has one row per output row: load the matching 4x4 block
    __global uchar *bias_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float))
                                + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, bias_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block
    vstore4(c0, 0, (__global float *)(out_addr + 0 * dst_stride_y + zout.s0));
    vstore4(c1, 0, (__global float *)(out_addr + 1 * dst_stride_y + zout.s1));
    vstore4(c2, 0, (__global float *)(out_addr + 2 * dst_stride_y + zout.s2));
    vstore4(c3, 0, (__global float *)(out_addr + 3 * dst_stride_y + zout.s3));
}
3809
/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003811 *
Gian Marco19835e52018-01-30 13:35:54 +00003812 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003813 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3814 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3815 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3816 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3817 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003818 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003819 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3820 * The activation function is performed after the bias addition
3821 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003822 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3823 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3824 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3825 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3826 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003827 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3828 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3829 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3830 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3831 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3832 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003833 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003834 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3835 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3836 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3837 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3838 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003839 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3840 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3841 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3842 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3843 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3844 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003845 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003846 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003847 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003848 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003849 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003850 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003851 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3852 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003853 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003854 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003855 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003856 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003857__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
3858 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003859#if defined(BETA)
3860 IMAGE_DECLARATION(src2),
3861#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00003862 IMAGE_DECLARATION(dst),
3863 uint src0_stride_z,
3864 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003865#if defined(BETA)
3866 uint src2_stride_z,
3867#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003868 uint dst_stride_z
3869#if defined(REINTERPRET_OUTPUT_AS_3D)
3870 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003871 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003872#endif // REINTERPRET_OUTPUT_AS_3D
3873 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003874{
Gian Marco36a0a462018-01-12 10:21:40 +00003875 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3876 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003877 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00003878
3879 // Offset
3880 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3881 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
3882
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003883 // src_addr_a = address of matrix A
3884 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003885 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3886 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3887
3888#if defined(MATRIX_B_DEPTH)
3889 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3890 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3891#else // defined(MATRIX_B_DEPTH)
3892 src1_addr_in_bytes += z * src1_stride_z;
3893#endif // defined(MATRIX_B_DEPTH)
3894
3895 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
3896 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003897
Gian Marco36a0a462018-01-12 10:21:40 +00003898 src_addr_a += offset_row_a;
3899 src_addr_b += offset_row_b;
3900
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003901 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003902 float4 c0 = 0.0f;
3903 float4 c1 = 0.0f;
3904 float4 c2 = 0.0f;
3905 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003906
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003907#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
3908
3909 int i = 0;
3910 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003911 {
3912 // Load values from matrix A (interleaved) and matrix B (transposed)
3913 float4 a0 = vload4(0, src_addr_a);
3914 float4 b0 = vload4(0, src_addr_b);
3915
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003916 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3917 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003918
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003919 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3920 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3921 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3922 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003923
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003924 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3925 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3926 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3927 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003928
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003929 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3930 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3931 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3932 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003933
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003934 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3935 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3936 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3937 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003938
3939 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003940 a0 = vload4(0, src_addr_a);
3941 b0 = vload4(0, src_addr_b);
3942
3943 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3944 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003945
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003946 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3947 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3948 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3949 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003950
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003951 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3952 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3953 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3954 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003955
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003956 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3957 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3958 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3959 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003960
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003961 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3962 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3963 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3964 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003965
3966 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003967 a0 = vload4(0, src_addr_a);
3968 b0 = vload4(0, src_addr_b);
3969
3970 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3971 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
3972
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003973 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3974 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3975 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3976 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003977
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003978 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3979 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3980 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3981 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003982
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003983 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3984 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3985 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3986 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003987
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003988 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3989 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3990 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3991 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003992
3993 // Load values from matrix A (interleaved) and matrix B (transposed)
3994 a0 = vload4(0, src_addr_a);
3995 b0 = vload4(0, src_addr_b);
3996
3997 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3998 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003999
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004000 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
4001 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
4002 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
4003 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004004
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004005 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
4006 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
4007 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
4008 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004009
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004010 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
4011 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
4012 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
4013 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004014
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004015 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
4016 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
4017 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
4018 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004019 }
4020
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004021 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004022 {
4023 // Load values from matrix A (interleaved) and matrix B (transposed)
4024 float4 a0 = vload4(0, src_addr_a);
4025 float4 b0 = vload4(0, src_addr_b);
4026
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004027 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4028 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
4029
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004030 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
4031 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
4032 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
4033 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004034
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004035 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
4036 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
4037 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
4038 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004039
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004040 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
4041 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
4042 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
4043 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004044
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004045 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
4046 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
4047 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
4048 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004049 }
4050
4051 // Compute destination address
4052 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4053
Gian Marcoae2af742018-02-15 12:35:44 +00004054 // Compute dst address
4055 __global uchar *dst_addr = offset(&dst, 0, 0);
4056
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004057 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004058
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004059#if defined(REINTERPRET_OUTPUT_AS_3D)
4060 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004061 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004062 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004063 // | |
4064 // | plane0 |
4065 // | |
4066 // |__________________|
4067 // |******************|
4068 // | cross_plane_pad |
4069 // |******************|
4070 // | |
4071 // | plane1 |
4072 // | |
4073 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004074
4075 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004076 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4077 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004078
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004079 // Add offset due to the cross plane paddings
4080 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004081
4082 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4083 // multiply dst_stride_z by DEPTH_GEMM3D
4084 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004085#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00004086 // Add offset for batched GEMM
4087 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004088#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4089
4090 // Multiply by the weight of matrix-matrix product and store the result
4091#if defined(ALPHA)
4092 SCALE_BLOCK(4, float, c, ALPHA);
4093#endif // defined(ALPHA)
4094
4095 // Add beta*bias
4096#if defined(BETA)
4097 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4098
4099#if defined(BROADCAST_BIAS)
4100 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
4101
4102 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4103
4104#ifndef UNIT_BETA
4105 SCALE_BLOCK(1, float, bias, BETA);
4106#endif // UNIT_BIAS
4107
4108 // c = c + bias[broadcasted]
4109 ADD_BLOCK_BROADCAST(4, c, bias0);
4110
4111#else // defined(BROADCAST_BIAS)
4112 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4113 2) * src2_stride_z;
4114
4115 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4116
4117#ifndef UNIT_BETA
4118 SCALE_BLOCK(4, float, bias, BETA);
4119#endif // UNIT_BIAS
4120
4121 // c = c + bias
4122 ADD_BLOCK(4, c, bias);
4123
4124#endif // defined(BROADCAST_BIAS)
4125#endif // defined(BETA)
4126
4127#if defined(ACTIVATION_TYPE)
4128 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
4129#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00004130
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004131 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004132 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4133 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4134 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4135 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004136}
4137
Georgios Pinitas84225582018-05-14 12:00:05 +01004138// Undefine local defines
4139#undef COLS_MTX_B
4140
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004141#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004142/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004143 *
Gian Marco19835e52018-01-30 13:35:54 +00004144 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004145 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
4146 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
4147 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4148 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004149 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004150 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4151 * The activation function is performed after the bias addition
4152 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004153 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4154 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4155 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4156 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4157 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004158 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4159 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4160 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4161 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4162 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4163 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004164 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004165 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4166 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4167 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4168 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4169 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4171 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4172 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4173 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4174 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4175 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004176 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004177 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00004178 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004179 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00004180 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004181 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004182 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4183 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004184 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004185 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004186 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004187 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004188__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
4189 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004190#if defined(BETA)
4191 IMAGE_DECLARATION(src2),
4192#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004193 IMAGE_DECLARATION(dst),
4194 uint src0_stride_z,
4195 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004196#if defined(BETA)
4197 uint src2_stride_z,
4198#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004199 uint dst_stride_z
4200#if defined(REINTERPRET_OUTPUT_AS_3D)
4201 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004202 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004203#endif // REINTERPRET_OUTPUT_AS_3D
4204 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004205{
Gian Marco36a0a462018-01-12 10:21:40 +00004206 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
4207 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00004208 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004209
Gian Marco36a0a462018-01-12 10:21:40 +00004210 // Offset
4211 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
4212 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004213
Gian Marco36a0a462018-01-12 10:21:40 +00004214 // src_addr_a = address of matrix A
4215 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004216 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
4217 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
4218
4219#if defined(MATRIX_B_DEPTH)
4220 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4221 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
4222#else // defined(MATRIX_B_DEPTH)
4223 src1_addr_in_bytes += z * src1_stride_z;
4224#endif // defined(MATRIX_B_DEPTH)
4225
4226 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
4227 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004228
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004229 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00004230 __global half *src_end_addr_b = src_addr_b + COLS_B;
4231
4232 src_addr_a += offset_row_a;
4233 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004234
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004235 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004236 half8 c0 = 0.0f;
4237 half8 c1 = 0.0f;
4238 half8 c2 = 0.0f;
4239 half8 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004240
Gian Marco36a0a462018-01-12 10:21:40 +00004241 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004242 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004243 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004244 half4 a0 = vload4(0, src_addr_a);
4245 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004246
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004247 c0 += (half8)a0.s0 * b0;
4248 c1 += (half8)a0.s1 * b0;
4249 c2 += (half8)a0.s2 * b0;
4250 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004251
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004252 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004253 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
4254 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004255
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004256 c0 += (half8)a0.s0 * b0;
4257 c1 += (half8)a0.s1 * b0;
4258 c2 += (half8)a0.s2 * b0;
4259 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004260 }
4261
Gian Marco36a0a462018-01-12 10:21:40 +00004262 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004263 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004264 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004265 half4 a0 = vload4(0, src_addr_a);
4266 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004267
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004268 c0 += (half8)a0.s0 * b0;
4269 c1 += (half8)a0.s1 * b0;
4270 c2 += (half8)a0.s2 * b0;
4271 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004272 }
4273
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004274 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004275 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4276
Gian Marcoae2af742018-02-15 12:35:44 +00004277 // Compute dst address
4278 __global uchar *dst_addr = offset(&dst, 0, 0);
4279
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004280 uint4 zout = 0;
4281
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004282#if defined(REINTERPRET_OUTPUT_AS_3D)
4283 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004284 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004285 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004286 // | |
4287 // | plane0 |
4288 // | |
4289 // |__________________|
4290 // |******************|
4291 // | cross_plane_pad |
4292 // |******************|
4293 // | |
4294 // | plane1 |
4295 // | |
4296 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004297
4298 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004299 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4300 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004301
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004302 // Add offset due to the cross plane paddings
4303 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004304
4305 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4306 // multiply dst_stride_z by DEPTH_GEMM3D
4307 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004308#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00004309 // Add offset for batched GEMM
4310 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004311#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4312
4313 // Multiply by the weight of matrix-matrix product and store the result
4314#if defined(ALPHA)
4315 SCALE_BLOCK(4, half, c, ALPHA);
4316#endif // defined(ALPHA)
4317
4318 // Add beta*bias
4319#if defined(BETA)
4320 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4321
4322#if defined(BROADCAST_BIAS)
4323 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4324
4325 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4326
4327#ifndef UNIT_BETA
4328 SCALE_BLOCK(1, half, bias, BETA);
4329#endif // UNIT_BIAS
4330
4331 // c = c + bias[broadcasted]
4332 ADD_BLOCK_BROADCAST(4, c, bias0);
4333
4334#else // defined(BROADCAST_BIAS)
4335
4336 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4337 2) * src2_stride_z;
4338
4339 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4340
4341#ifndef UNIT_BETA
4342 SCALE_BLOCK(4, half, bias, BETA);
4343#endif // UNIT_BIAS
4344
4345 // c = c + bias
4346 ADD_BLOCK(4, c, bias);
4347
4348#endif // defined(BROADCAST_BIAS)
4349#endif // defined(BETA)
4350
4351#if defined(ACTIVATION_TYPE)
4352 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4353#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00004354
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004355 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004356 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4357 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4358 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4359 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004360}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004361
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004362/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32 floating point variable.
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004363 *
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004364 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004365 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
4366 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
4367 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4368 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004369 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004370 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4371 * The activation function is performed after the bias addition
4372 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004373 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4374 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4375 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4376 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4377 *
4378 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4379 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4380 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4381 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4382 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4383 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4384 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4385 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4386 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4387 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4388 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4389 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004390 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
4391 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4392 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4393 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4394 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4395 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004396 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4397 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4398 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4399 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4400 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4401 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
4402 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4403 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004404 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004405 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
4406 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
4407 */
4408__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
4409 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004410#if defined(BETA)
4411 IMAGE_DECLARATION(src2),
4412#endif // defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004413 IMAGE_DECLARATION(dst),
4414 uint src0_stride_z,
4415 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004416#if defined(BETA)
4417 uint src2_stride_z,
4418#endif //defined(BETA)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004419 uint dst_stride_z
4420#if defined(REINTERPRET_OUTPUT_AS_3D)
4421 ,
4422 uint cross_plane_pad
4423#endif // REINTERPRET_OUTPUT_AS_3D
4424 )
4425{
4426 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
4427 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
4428 int z = get_global_id(2);
4429
4430 // Offset
4431 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
4432 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
4433
4434 // src_addr_a = address of matrix A
4435 // src_addr_b = address of matrix B
4436 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
4437 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
4438
4439#if defined(MATRIX_B_DEPTH)
4440 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4441 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
4442#else // defined(MATRIX_B_DEPTH)
4443 src1_addr_in_bytes += z * src1_stride_z;
4444#endif // defined(MATRIX_B_DEPTH)
4445
4446 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
4447 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
4448
4449 // Compute end row address for matrix B
4450 __global half *src_end_addr_b = src_addr_b + COLS_B;
4451
4452 src_addr_a += offset_row_a;
4453 src_addr_b += offset_row_b;
4454
4455 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004456 float8 c0 = 0.0f;
4457 float8 c1 = 0.0f;
4458 float8 c2 = 0.0f;
4459 float8 c3 = 0.0f;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004460
4461 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
4462 {
4463 // Load values from matrix A (interleaved) and matrix B (transposed)
4464 float4 a0 = convert_float4(vload4(0, src_addr_a));
4465 float8 b0 = convert_float8(vload8(0, src_addr_b));
4466
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004467 c0 += (float8)a0.s0 * b0;
4468 c1 += (float8)a0.s1 * b0;
4469 c2 += (float8)a0.s2 * b0;
4470 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004471
4472 // Load values from matrix A (interleaved) and matrix B (transposed)
4473 a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
4474 b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
4475
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004476 c0 += (float8)a0.s0 * b0;
4477 c1 += (float8)a0.s1 * b0;
4478 c2 += (float8)a0.s2 * b0;
4479 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004480 }
4481
4482 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
4483 {
4484 // Load values from matrix A (interleaved) and matrix B (transposed)
4485 float4 a0 = convert_float4(vload4(0, src_addr_a));
4486 float8 b0 = convert_float8(vload8(0, src_addr_b));
4487
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004488 c0 += (float8)a0.s0 * b0;
4489 c1 += (float8)a0.s1 * b0;
4490 c2 += (float8)a0.s2 * b0;
4491 c3 += (float8)a0.s3 * b0;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004492 }
4493
4494 // Compute destination address
4495 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4496
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004497 // Compute dst address
4498 __global uchar *dst_addr = offset(&dst, 0, 0);
4499
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004500 uint4 zout = 0;
4501
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004502#if defined(REINTERPRET_OUTPUT_AS_3D)
4503 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
4504 // in order to take into account the presence of possible cross plane paddings
4505 //
4506 // | |
4507 // | plane0 |
4508 // | |
4509 // |__________________|
4510 // |******************|
4511 // | cross_plane_pad |
4512 // |******************|
4513 // | |
4514 // | plane1 |
4515 // | |
4516 // |__________________|
4517
4518 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004519 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4520 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004521
4522 // Add offset due to the cross plane paddings
4523 zout *= (cross_plane_pad * dst_stride_y);
4524
4525 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4526 // multiply dst_stride_z by DEPTH_GEMM3D
4527 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004528#else // defined(REINTERPRET_OUTPUT_AS_3D)
4529 // Add offset for batched GEMM
4530 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004531#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4532
4533 // Multiply by the weight of matrix-matrix product and store the result
4534#if defined(ALPHA)
4535 SCALE_BLOCK(4, float, c, ALPHA);
4536#endif // defined(ALPHA)
4537
4538#if defined(BETA)
4539 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4540
4541#if defined(BROADCAST_BIAS)
4542 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4543
4544 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4545
4546 float8 bias_f0 = convert_float8(bias0);
4547
4548#ifndef UNIT_BETA
4549 SCALE_BLOCK(1, float, bias_f, BETA);
4550#endif // UNIT_BIAS
4551
4552 // c = c + bias[broadcasted]
4553 ADD_BLOCK_BROADCAST(4, c, bias_f0);
4554
4555#else // defined(BROADCAST_BIAS)
4556 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4557 2) * src2_stride_z;
4558
4559 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4560
4561 float8 bias_f0 = convert_float8(bias0);
4562 float8 bias_f1 = convert_float8(bias1);
4563 float8 bias_f2 = convert_float8(bias2);
4564 float8 bias_f3 = convert_float8(bias3);
4565
4566#ifndef UNIT_BETA
4567 SCALE_BLOCK(4, float, bias_f, BETA);
4568#endif // UNIT_BIAS
4569
4570 // c = c + bias
4571 ADD_BLOCK(4, c, bias_f);
4572
4573#endif // defined(BROADCAST_BIAS)
4574#endif // defined(BETA)
4575
4576 half8 c_h0 = convert_half8(c0);
4577 half8 c_h1 = convert_half8(c1);
4578 half8 c_h2 = convert_half8(c2);
4579 half8 c_h3 = convert_half8(c3);
4580
4581#if defined(ACTIVATION_TYPE)
4582 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
4583#endif // defined(ACTIVATION_TYPE)
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004584
4585 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004586 vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4587 vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4588 vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4589 vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00004590}
4591
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004592/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004593 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004594 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004595 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
4596 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
4597 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4598 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004599 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004600 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4601 * The activation function is performed after the bias addition
4602 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004603 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4604 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4605 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4606 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4607 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004608 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4609 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4610 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4611 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4612 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4613 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
4614 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
4615 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4616 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4617 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4618 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4619 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004620 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
4621 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4622 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4623 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4624 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4625 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004626 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
4627 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
4628 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
4629 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
4630 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
4631 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004632 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4633 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
4634 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004635 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004636 */
4637__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
4638 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004639#if defined(BETA)
4640 IMAGE_DECLARATION(src2),
4641#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004642 IMAGE_DECLARATION(dst),
4643 uint src0_stride_z,
4644 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004645#if defined(BETA)
4646 uint src2_stride_z,
4647#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004648 uint dst_stride_z
4649#if defined(REINTERPRET_OUTPUT_AS_3D)
4650 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004651 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004652#endif // REINTERPRET_OUTPUT_AS_3D
4653 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004654{
4655 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
4656 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
4657 int z = get_global_id(2);
4658
4659 // Offset
4660 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
4661 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
4662
4663 // src_addr_a = address of matrix A
4664 // src_addr_b = address of matrix B
4665 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
4666 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
4667
4668#if defined(MATRIX_B_DEPTH)
4669 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4670 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
4671#else // defined(MATRIX_B_DEPTH)
4672 src1_addr_in_bytes += z * src1_stride_z;
4673#endif // defined(MATRIX_B_DEPTH)
4674
4675 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
4676 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
4677
4678 // Compute end row address for matrix B
4679 __global half *src_end_addr_b = src_addr_b + COLS_B;
4680
4681 src_addr_a += offset_row_a;
4682 src_addr_b += offset_row_b;
4683
4684 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004685 half8 c0 = 0.0f;
4686 half8 c1 = 0.0f;
4687 half8 c2 = 0.0f;
4688 half8 c3 = 0.0f;
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004689
4690#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
4691
4692 int i = 0;
4693 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
4694 {
4695#if MULT_INTERLEAVE4X4_HEIGHT == 1
4696 // Load values from matrix A (interleaved) and matrix B (transposed)
4697 half8 a0 = vload8(0, src_addr_a);
4698 half8 b0 = vload8(0, src_addr_b);
4699
4700 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
4701 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4702
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004703 c0 = fma((half8)a0.s0, b0, c0);
4704 c1 = fma((half8)a0.s1, b0, c1);
4705 c2 = fma((half8)a0.s2, b0, c2);
4706 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004707
4708 // Load values from matrix B (transposed)
4709 b0 = vload8(0, src_addr_b);
4710
4711 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4712
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004713 c0 = fma((half8)a0.s4, b0, c0);
4714 c1 = fma((half8)a0.s5, b0, c1);
4715 c2 = fma((half8)a0.s6, b0, c2);
4716 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004717
4718 // Load values from matrix A (interleaved) and matrix B (transposed)
4719 a0 = vload8(0, src_addr_a);
4720 b0 = vload8(0, src_addr_b);
4721
4722 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
4723 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4724
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004725 c0 = fma((half8)a0.s0, b0, c0);
4726 c1 = fma((half8)a0.s1, b0, c1);
4727 c2 = fma((half8)a0.s2, b0, c2);
4728 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004729
4730 // Load values from matrix B (transposed)
4731 b0 = vload8(0, src_addr_b);
4732
4733 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4734
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004735 c0 = fma((half8)a0.s4, b0, c0);
4736 c1 = fma((half8)a0.s5, b0, c1);
4737 c2 = fma((half8)a0.s6, b0, c2);
4738 c3 = fma((half8)a0.s7, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004739#else // MULT_INTERLEAVE4X4_HEIGHT == 1
4740 // Load values from matrix A (interleaved) and matrix B (transposed)
4741 half4 a0 = vload4(0, src_addr_a);
4742 half8 b0 = vload8(0, src_addr_b);
4743
4744 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4745 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4746
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004747 c0 = fma((half8)a0.s0, b0, c0);
4748 c1 = fma((half8)a0.s1, b0, c1);
4749 c2 = fma((half8)a0.s2, b0, c2);
4750 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004751
4752 // Load values from matrix A (interleaved) and matrix B (transposed)
4753 a0 = vload4(0, src_addr_a);
4754 b0 = vload8(0, src_addr_b);
4755
4756 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4757 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4758
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004759 c0 = fma((half8)a0.s0, b0, c0);
4760 c1 = fma((half8)a0.s1, b0, c1);
4761 c2 = fma((half8)a0.s2, b0, c2);
4762 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004763
4764 // Load values from matrix A (interleaved) and matrix B (transposed)
4765 a0 = vload4(0, src_addr_a);
4766 b0 = vload8(0, src_addr_b);
4767
4768 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4769 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4770
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004771 c0 = fma((half8)a0.s0, b0, c0);
4772 c1 = fma((half8)a0.s1, b0, c1);
4773 c2 = fma((half8)a0.s2, b0, c2);
4774 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004775
4776 // Load values from matrix A (interleaved) and matrix B (transposed)
4777 a0 = vload4(0, src_addr_a);
4778 b0 = vload8(0, src_addr_b);
4779
4780 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4781 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4782
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004783 c0 = fma((half8)a0.s0, b0, c0);
4784 c1 = fma((half8)a0.s1, b0, c1);
4785 c2 = fma((half8)a0.s2, b0, c2);
4786 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004787#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
4788 }
4789
4790 for(; i < (int)(COLS_MTX_B); ++i)
4791 {
4792 // Load values from matrix A (interleaved) and matrix B (transposed)
4793 half4 a0 = vload4(0, src_addr_a);
4794 half8 b0 = vload8(0, src_addr_b);
4795
4796 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4797 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
4798
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004799 c0 = fma((half8)a0.s0, b0, c0);
4800 c1 = fma((half8)a0.s1, b0, c1);
4801 c2 = fma((half8)a0.s2, b0, c2);
4802 c3 = fma((half8)a0.s3, b0, c3);
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004803 }
4804
4805 // Compute destination address
4806 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4807
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004808 // Compute dst address
4809 __global uchar *dst_addr = offset(&dst, 0, 0);
4810
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004811 uint4 zout = 0;
4812
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004813#if defined(REINTERPRET_OUTPUT_AS_3D)
4814 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004815 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004816 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004817 // | |
4818 // | plane0 |
4819 // | |
4820 // |__________________|
4821 // |******************|
4822 // | cross_plane_pad |
4823 // |******************|
4824 // | |
4825 // | plane1 |
4826 // | |
4827 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004828
4829 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004830 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4831 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004832
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004833 // Add offset due to the cross plane paddings
4834 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004835
4836 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4837 // multiply dst_stride_z by DEPTH_GEMM3D
4838 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004839#else // defined(REINTERPRET_OUTPUT_AS_3D)
4840 // Add offset for batched GEMM
4841 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004842#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4843
4844 // Multiply by the weight of matrix-matrix product and store the result
4845#if defined(ALPHA)
4846 SCALE_BLOCK(4, half, c, ALPHA);
4847#endif // defined(ALPHA)
4848
4849 // Add beta*bias
4850#if defined(BETA)
4851 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4852
4853#if defined(BROADCAST_BIAS)
4854 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4855
4856 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4857
4858#ifndef UNIT_BETA
4859 SCALE_BLOCK(1, half, bias, BETA);
4860#endif // UNIT_BIAS
4861
4862 // c = c + bias[broadcasted]
4863 ADD_BLOCK_BROADCAST(4, c, bias0);
4864
4865#else // defined(BROADCAST_BIAS)
4866 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4867 2) * src2_stride_z;
4868
4869 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4870
4871#ifndef UNIT_BETA
4872 SCALE_BLOCK(4, half, bias, BETA);
4873#endif // UNIT_BIAS
4874
4875 // c = c + bias
4876 ADD_BLOCK(4, c, bias);
4877
4878#endif // defined(BROADCAST_BIAS)
4879#endif // defined(BETA)
4880
4881#if defined(ACTIVATION_TYPE)
4882 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4883#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004884
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004885 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004886 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4887 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4888 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4889 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004890}
Georgios Pinitas84225582018-05-14 12:00:05 +01004891
4892// Undefine local defines
4893#undef COLS_MTX_B
4894
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004895#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004896
Gian Marco36a0a462018-01-12 10:21:40 +00004897#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004898
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004899#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * Each work-item accumulates a block of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by NUM_ELEMS_PROCESSED_PER_THREAD_X columns of the destination.
 * The kernel body declares accumulators for at most 4 rows (acc0..acc3).
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note The bias matrix (src2) is only part of the kernel signature when -DBETA is passed at compile time. The bias is multiplied by BETA unless -DUNIT_BETA is defined.
 *       If -DBROADCAST_BIAS is defined, a single row of the bias is loaded and added to every row of the output block.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = number of rows (M) of the matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif // defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First output column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (s0 = byte offset into src0, s1 = byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the current row of matrix A
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    // One accumulator vector per output row
    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume two elements of each row of A (and two rows of B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: consume the remaining elements of the row of A one at a time
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row byte offset due to cross plane paddings (stays 0 unless the output is reinterpreted as 3D)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // The bias is never reinterpreted as 3D here: all its cross plane offsets are zero
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Only the x offset is applied: the same bias row is used for every output row
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005213
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01005214/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005215 *
5216 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
5217 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5218 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5219 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5220 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005221 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5222 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005223 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005224 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5225 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005226 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5227 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005228 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5229 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5230 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5231 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5232 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005233 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005234 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5235 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5236 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5237 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5238 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5239 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5240 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5241 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5242 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5243 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5244 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5246 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5247 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5248 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5249 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5250 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005251 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5252 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5253 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5254 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5255 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5256 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005257 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5258 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005259 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005260 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005261 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5262 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005263 */
5264__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
5265 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005266#if defined(BETA)
5267 IMAGE_DECLARATION(src2),
5268#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00005269 IMAGE_DECLARATION(dst),
5270 uint src0_stride_z,
5271 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005272#if defined(BETA)
5273 uint src2_stride_z,
5274#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005275 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005276#if defined(REINTERPRET_INPUT_AS_3D)
5277 ,
5278 uint src_cross_plane_pad
5279#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005280#if defined(REINTERPRET_OUTPUT_AS_3D)
5281 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005282 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005283#endif // REINTERPRET_OUTPUT_AS_3D
5284 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005285{
5286 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5287
5288 // Compute starting address for matrix A and matrix B
5289 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5290
5291 // Update address for matrix A
5292 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5293
5294 // Update address for matrix B
5295 src_addr.s1 += idx * sizeof(float);
5296
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005297#if defined(REINTERPRET_INPUT_AS_3D)
5298 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5299 // in order to take into account the presence of possible cross plane paddings
5300 //
5301 // | |
5302 // | plane0 |
5303 // | |
5304 // |__________________|
5305 // |******************|
5306 // | cross_plane_pad |
5307 // |******************|
5308 // | |
5309 // | plane1 |
5310 // | |
5311 // |__________________|
5312
5313 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5314 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5315 zin = min(DEPTH_GEMM3D - 1, zin);
5316
5317 // Add offset due to the cross plane paddings
5318 zin *= (src_cross_plane_pad * src0_stride_y);
5319
5320 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5321 // multiply src0_stride_z by DEPTH_GEMM3D
5322 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5323
5324#else // defined(REINTERPRET_INPUT_AS_3D)
5325
Gian Marcoae2af742018-02-15 12:35:44 +00005326 // Add offset for batched GEMM
5327 src_addr.s0 += get_global_id(2) * src0_stride_z;
5328
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005329#endif // defined(REINTERPRET_INPUT_AS_3D)
5330
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005331#if defined(MATRIX_B_DEPTH)
5332 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5333 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5334#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005335 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005336#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005337
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005338 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005339 float4 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005340
5341#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005342 float4 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005343#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5344
5345#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005346 float4 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005347#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5348
5349#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005350 float4 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005351#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5352
5353 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005354 int i = 0;
5355 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005356 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005357#if defined(REINTERPRET_INPUT_AS_3D)
5358 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01005359 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5360#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005361 // Load values from matrix A and matrix B
5362 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005363#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005364 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005365#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5366#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005367 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005368#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5369#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005370 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005371#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005372#endif // defined(REINTERPRET_INPUT_AS_3D)
5373
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005374 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5375 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005376
5377 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005378 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5379 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5380 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
5381 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005382
5383#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005384
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005385 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
5386 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
5387 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
5388 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005389
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005390#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5391#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005392
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005393 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
5394 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
5395 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
5396 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005397
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005398#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5399#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005400
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005401 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
5402 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
5403 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
5404 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005405#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005406
5407 // Load values from matrix A and matrix B
5408 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5409 src_addr.s1 += src1_stride_y;
5410
5411 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005412 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
5413 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
5414 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
5415 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005416
5417#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5418
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005419 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
5420 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
5421 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
5422 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005423
5424#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5425#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5426
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005427 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
5428 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
5429 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
5430 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005431
5432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5434
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005435 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
5436 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
5437 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
5438 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005439#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5440
5441 // Load values from matrix A and matrix B
5442 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5443 src_addr.s1 += src1_stride_y;
5444
5445 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005446 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
5447 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
5448 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
5449 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005450
5451#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5452
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005453 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
5454 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
5455 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
5456 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005457
5458#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5459#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5460
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005461 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
5462 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
5463 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
5464 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005465
5466#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5467#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5468
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005469 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
5470 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
5471 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
5472 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005473#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5474
5475 // Load values from matrix A and matrix B
5476 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5477 src_addr.s1 += src1_stride_y;
5478
5479 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005480 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
5481 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
5482 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
5483 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005484
5485#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5486
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005487 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
5488 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
5489 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
5490 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005491
5492#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5493#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5494
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005495 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
5496 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
5497 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
5498 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005499
5500#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5501#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5502
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005503 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
5504 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
5505 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
5506 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005507#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5508
5509 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005510 }
5511
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005512 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005513 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005514#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005515 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005516 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5517#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5518 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5519#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5520#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5521 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5522#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5523#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5524 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5525#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5526#else // defined(REINTERPRET_INPUT_AS_3D)
5527 // Load values from matrix A
5528 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005529#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5530 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5531#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5532#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5533 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5534#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5535#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5536 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5537#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005538#endif // defined(REINTERPRET_INPUT_AS_3D)
5539
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005540 // Load values from matrix B
5541 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005542 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005543
5544 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005545 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5546 acc0.s1 = fma(a0, b0.s1, acc0.s1);
5547 acc0.s2 = fma(a0, b0.s2, acc0.s2);
5548 acc0.s3 = fma(a0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005549#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005550 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5551 acc1.s1 = fma(a1, b0.s1, acc1.s1);
5552 acc1.s2 = fma(a1, b0.s2, acc1.s2);
5553 acc1.s3 = fma(a1, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005554#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5555#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005556 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5557 acc2.s1 = fma(a2, b0.s1, acc2.s1);
5558 acc2.s2 = fma(a2, b0.s2, acc2.s2);
5559 acc2.s3 = fma(a2, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005560#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5561#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005562 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5563 acc3.s1 = fma(a3, b0.s1, acc3.s1);
5564 acc3.s2 = fma(a3, b0.s2, acc3.s2);
5565 acc3.s3 = fma(a3, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005566#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005567
5568 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005569 }
5570
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005571 int z = get_global_id(2);
5572
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005573 // Compute destination address
5574 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5575
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005576 // Compute dst address
5577 __global uchar *dst_addr = offset(&dst, 0, 0);
5578
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005579 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005580
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005581#if defined(REINTERPRET_OUTPUT_AS_3D)
5582 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005583 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005584 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005585 // | |
5586 // | plane0 |
5587 // | |
5588 // |__________________|
5589 // |******************|
5590 // | cross_plane_pad |
5591 // |******************|
5592 // | |
5593 // | plane1 |
5594 // | |
5595 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005596
5597 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005598 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5599 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005600
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005601 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005602 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005603
5604 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5605 // multiply dst_stride_z by DEPTH_GEMM3D
5606 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005607#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005608 // Add offset for batched GEMM
5609 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005610#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5611
5612 // Multiply by the weight of matrix-matrix product and store the result
5613#if defined(ALPHA)
5614 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5615#endif // defined(ALPHA)
5616
5617 // Add beta*bias
5618#if defined(BETA)
5619 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5620
5621#if defined(BROADCAST_BIAS)
5622 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
5623
5624 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
5625
5626#ifndef UNIT_BETA
5627 SCALE_BLOCK(1, float, bias, BETA);
5628#endif // UNIT_BIAS
5629
5630 // acc = acc + bias[broadcasted]
5631 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5632
5633#else // defined(BROADCAST_BIAS)
5634 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
5635 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5636
5637 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
5638
5639#ifndef UNIT_BETA
5640 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
5641#endif // UNIT_BIAS
5642
5643 // acc = acc + bias
5644 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5645
5646#endif // defined(BROADCAST_BIAS)
5647#endif // defined(BETA)
5648
5649#if defined(ACTIVATION_TYPE)
5650 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
5651#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005652
5653 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005654 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005655#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005656 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005657#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5658#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005659 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005660#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5661#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005662 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005663#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005664}
5665
5666/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5667 *
5668 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
5669 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
5670 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5671 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
5672 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5673 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005674 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5675 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005676 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005677 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
5678 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005679 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5680 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005681 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5682 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5683 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5684 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
5685 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005686 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005687 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5688 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
5689 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5690 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
5691 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5692 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5693 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5694 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
5695 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5696 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
5697 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005698 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5699 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5700 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5701 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5702 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5703 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005704 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5705 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5706 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
5707 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5708 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
5709 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005710 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5711 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005712 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005713 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005714 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5715 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005716 */
5717__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
5718 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005719#if defined(BETA)
5720 IMAGE_DECLARATION(src2),
5721#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00005722 IMAGE_DECLARATION(dst),
5723 uint src0_stride_z,
5724 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005725#if defined(BETA)
5726 uint src2_stride_z,
5727#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005728 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005729#if defined(REINTERPRET_INPUT_AS_3D)
5730 ,
5731 uint src_cross_plane_pad
5732#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005733#if defined(REINTERPRET_OUTPUT_AS_3D)
5734 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005735 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005736#endif // REINTERPRET_OUTPUT_AS_3D
5737 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005738{
5739 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5740 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5741
5742 // Compute starting address for matrix A and Matrix B
5743 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5744
5745 // Update address for the matrix A
5746 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5747
5748 // Update address for the matrix B
5749 src_addr.s1 += idx * sizeof(float);
5750
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005751#if defined(REINTERPRET_INPUT_AS_3D)
5752 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5753 // in order to take into account the presence of possible cross plane paddings
5754 //
5755 // | |
5756 // | plane0 |
5757 // | |
5758 // |__________________|
5759 // |******************|
5760 // | cross_plane_pad |
5761 // |******************|
5762 // | |
5763 // | plane1 |
5764 // | |
5765 // |__________________|
5766
5767 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5768 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5769 zin = min(DEPTH_GEMM3D - 1, zin);
5770
5771 // Add offset due to the cross plane paddings
5772 zin *= (src_cross_plane_pad * src0_stride_y);
5773
5774 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5775 // multiply src0_stride_z by DEPTH_GEMM3D
5776 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5777
5778#else // defined(REINTERPRET_INPUT_AS_3D)
5779
Gian Marcoae2af742018-02-15 12:35:44 +00005780 // Add offset for batched GEMM
5781 src_addr.s0 += get_global_id(2) * src0_stride_z;
5782
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005783#endif // defined(REINTERPRET_INPUT_AS_3D)
5784
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005785#if defined(MATRIX_B_DEPTH)
5786 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5787 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5788#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005789 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005790#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005791
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005792 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005793 float2 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005794#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005795 float2 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005796#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5797#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005798 float2 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005799#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5800#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005801 float2 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005802#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5803
5804 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005805 int i = 0;
5806 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005807 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005808#if defined(REINTERPRET_INPUT_AS_3D)
5809 // Load values from matrix A
5810 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
5811#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005812 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005813 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005814#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005815
5816 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005817 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5818 src_addr.s1 += src1_stride_y;
5819 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5820 src_addr.s1 += src1_stride_y;
5821 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5822 src_addr.s1 += src1_stride_y;
5823 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5824 src_addr.s1 += src1_stride_y;
5825 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5826 src_addr.s1 += src1_stride_y;
5827 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5828 src_addr.s1 += src1_stride_y;
5829 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5830 src_addr.s1 += src1_stride_y;
5831 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5832 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005833
5834 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005835 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5836 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
5837 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
5838 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
5839 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
5840 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
5841 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
5842 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005843
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005844 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5845 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
5846 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
5847 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
5848 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
5849 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
5850 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
5851 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005852
5853#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005854#if defined(REINTERPRET_INPUT_AS_3D)
5855 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5856#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005857 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005858#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005859 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
5860 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
5861 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
5862 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
5863 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
5864 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
5865 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
5866 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005867
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005868 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
5869 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
5870 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
5871 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
5872 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
5873 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
5874 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
5875 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005876#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5877#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005878#if defined(REINTERPRET_INPUT_AS_3D)
5879 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5880#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005881 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005882#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005883 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
5884 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
5885 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
5886 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
5887 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
5888 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
5889 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
5890 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005891
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005892 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
5893 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
5894 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
5895 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
5896 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
5897 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
5898 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
5899 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005900#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5901#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005902#if defined(REINTERPRET_INPUT_AS_3D)
5903 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5904#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005905 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005906#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005907 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
5908 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
5909 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
5910 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
5911 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
5912 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
5913 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
5914 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005915
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005916 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
5917 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
5918 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
5919 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
5920 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
5921 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
5922 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
5923 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005924#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005925
5926 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005927 }
5928 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005929 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005930 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005931#if defined(REINTERPRET_INPUT_AS_3D)
5932 // Load values from matrix A
5933 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5934#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5935 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5936#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5937#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5938 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5939#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5940#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5941 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5942#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5943#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005944 // Load values from matrix A
5945 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5946#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5947 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5948#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5949#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5950 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5951#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5952#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5953 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5954#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005955#endif // defined(REINTERPRET_INPUT_AS_3D)
5956
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005957 // Load values from matrix B
5958 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005959 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005960
5961 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005962 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5963 acc0.s1 = fma(a0, b0.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005964#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005965 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5966 acc1.s1 = fma(a1, b0.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005967#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5968#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005969 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5970 acc2.s1 = fma(a2, b0.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005971#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5972#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005973 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5974 acc3.s1 = fma(a3, b0.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005975#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005976
5977 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005978 }
5979
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005980 int z = get_global_id(2);
5981
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005982 // Compute destination address
5983 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5984
Gian Marcoae2af742018-02-15 12:35:44 +00005985 // Compute dst address
5986 __global uchar *dst_addr = offset(&dst, 0, 0);
5987
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005988 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005989
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005990#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005991
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005992 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005993 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005994 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005995 // | |
5996 // | plane0 |
5997 // | |
5998 // |__________________|
5999 // |******************|
6000 // | cross_plane_pad |
6001 // |******************|
6002 // | |
6003 // | plane1 |
6004 // | |
6005 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00006006
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006007 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006008 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6009 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006010
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006011 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006012 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006013
6014 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6015 // multiply dst_stride_z by DEPTH_GEMM3D
6016 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006017#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006018 // Add offset for batched GEMM
6019 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006020#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6021
6022 // Multiply by the weight of matrix-matrix product and store the result
6023#if defined(ALPHA)
6024 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
6025#endif // defined(ALPHA)
6026
6027 // Add beta*bias
6028#if defined(BETA)
6029 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
6030
6031#if defined(BROADCAST_BIAS)
6032 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
6033
6034 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
6035
6036#ifndef UNIT_BETA
6037 SCALE_BLOCK(1, float, bias, BETA);
6038#endif // UNIT_BIAS
6039
6040 // acc = acc + bias[broadcasted]
6041 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
6042
6043#else // defined(BROADCAST_BIAS)
6044 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
6045 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
6046
6047 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
6048
6049#ifndef UNIT_BETA
6050 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
6051#endif // UNIT_BIAS
6052
6053 // acc = acc + bias
6054 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
6055
6056#endif // defined(BROADCAST_BIAS)
6057#endif // defined(BETA)
6058
6059#if defined(ACTIVATION_TYPE)
6060 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
6061#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006062
6063 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006064 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006065#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006066 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006067#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6068#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006069 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006070#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6071#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006072 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006073#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006074}
6075
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006076#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in 32-bit floating point variables.
 *       The accumulators are converted back to half only after the alpha scaling and the bias addition.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       The body hard-codes a tile width of 8 along X (8-wide loads of matrix B, 8-wide bias load and 8-wide store), while the matrix B offset is derived from
 *       NUM_ELEMS_PROCESSED_PER_THREAD_X. The two only agree for -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8.
 *       NOTE(review): this note previously recommended -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4 - confirm the value configured by the host code.
 *       At most 4 rows along Y are handled (NUM_ELEMS_PROCESSED_PER_THREAD_Y in [1, 4]).
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note The optional bias (src2) is enabled at compile time using -DBETA=beta. If -DUNIT_BETA is passed, the bias is added without being scaled by beta.
 *       If -DBROADCAST_BIAS is passed, a single row of the bias is loaded and added to every row of the output tile.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition, on the values already converted to half
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif // defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // First column of matrix B (in elements) handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    // (.s0 is the byte offset into src0, .s1 the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // 32-bit accumulators, one 8-wide row per output row handled by this work-item
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of matrix A (and 4 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A (each row gets its own cross-plane offset zin.sN)
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (row k + 0) and widen them to float
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate (the half element of A is widened to float8 by the cast)
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Matrix B row k + 1
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Matrix B row k + 2
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Matrix B row k + 3
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: consume the remaining (COLS_A % 4) columns of matrix A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        // Advance matrix A by one column and matrix B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (float8)a0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row byte offset due to cross plane paddings (stays 0 when the output is not reinterpreted as 3D)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product (still in 32-bit precision)
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // Zero Z offsets passed to LOAD_BLOCK for the bias rows
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    // A single bias row is loaded and shared by all the rows of the tile
    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // Widen the bias to float so that it is added in 32-bit precision
    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    // Widen the bias to float so that it is added in 32-bit precision
    float8 bias_f0 = convert_float8(bias0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 bias_f1 = convert_float8(bias1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 bias_f2 = convert_float8(bias2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 bias_f3 = convert_float8(bias3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Narrow the 32-bit accumulators to the output data type
    half8 acc_h0 = convert_half8(acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc_h1 = convert_half8(acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc_h2 = convert_half8(acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc_h3 = convert_half8(acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // The activation is applied on the half values
#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
}
6445
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
6447 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006448 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
6449 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
6450 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
6451 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
6452 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006453 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
6454 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006455 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006456 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
6457 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006458 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
6459 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006460 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
6461 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
6462 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
6463 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
6464 *
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006465 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
6466 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
6467 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6468 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
6469 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6470 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
6471 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
6472 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
6473 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6474 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
6475 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6476 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006477 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
6478 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
6479 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
6480 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
6481 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
6482 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006483 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
6484 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6485 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6486 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6487 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
6488 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006489 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
6490 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006491 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006492 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006493 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
6494 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006495 */
6496__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
6497 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006498#if defined(BETA)
6499 IMAGE_DECLARATION(src2),
6500#endif // defined(BETA)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006501 IMAGE_DECLARATION(dst),
6502 uint src0_stride_z,
6503 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006504#if defined(BETA)
6505 uint src2_stride_z,
6506#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006507 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006508#if defined(REINTERPRET_INPUT_AS_3D)
6509 ,
6510 uint src_cross_plane_pad
6511#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006512#if defined(REINTERPRET_OUTPUT_AS_3D)
6513 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006514 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006515#endif // REINTERPRET_OUTPUT_AS_3D
6516 )
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006517{
6518 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
6519
6520 // Compute starting address for matrix A and Matrix B
6521 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
6522
6523 // Update address for the matrix A
6524 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
6525
6526 // Update address for the matrix B
6527 src_addr.s1 += idx * sizeof(half);
6528
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006529#if defined(REINTERPRET_INPUT_AS_3D)
6530 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
6531 // in order to take into account the presence of possible cross plane paddings
6532 //
6533 // | |
6534 // | plane0 |
6535 // | |
6536 // |__________________|
6537 // |******************|
6538 // | cross_plane_pad |
6539 // |******************|
6540 // | |
6541 // | plane1 |
6542 // | |
6543 // |__________________|
6544
6545 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
6546 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6547 zin = min(DEPTH_GEMM3D - 1, zin);
6548
6549 // Add offset due to the cross plane paddings
6550 zin *= (src_cross_plane_pad * src0_stride_y);
6551
6552 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6553 // multiply src0_stride_z by DEPTH_GEMM3D
6554 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
6555
6556#else // defined(REINTERPRET_INPUT_AS_3D)
6557
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006558 // Add offset for batched GEMM
6559 src_addr.s0 += get_global_id(2) * src0_stride_z;
6560
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006561#endif // defined(REINTERPRET_INPUT_AS_3D)
6562
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006563#if defined(MATRIX_B_DEPTH)
6564 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
6565 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
6566#else // defined(MATRIX_B_DEPTH)
6567 src_addr.s1 += get_global_id(2) * src1_stride_z;
6568#endif // defined(MATRIX_B_DEPTH)
6569
6570 half8 acc0 = 0.0h;
6571#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6572 half8 acc1 = 0.0h;
6573#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6574#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6575 half8 acc2 = 0.0h;
6576#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6577#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6578 half8 acc3 = 0.0h;
6579#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6580
6581 int i = 0;
6582 for(; i <= ((int)COLS_A - 4); i += 4)
6583 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006584#if defined(REINTERPRET_INPUT_AS_3D)
6585 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01006586 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
6587#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006588 // Load values from matrix A
6589 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6590#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6591 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6592#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6593#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6594 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6595#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6596#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6597 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6598#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006599#endif // defined(REINTERPRET_INPUT_AS_3D)
6600
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006601 // Load values from matrix B
6602 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6603 src_addr.s1 += src1_stride_y;
6604
6605 // Accumulate
6606 acc0 = fma(b0, (half8)a0.s0, acc0);
6607#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6608 acc1 = fma(b0, (half8)a1.s0, acc1);
6609#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6610#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6611 acc2 = fma(b0, (half8)a2.s0, acc2);
6612#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6613#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6614 acc3 = fma(b0, (half8)a3.s0, acc3);
6615#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6616
6617 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6618 src_addr.s1 += src1_stride_y;
6619 acc0 = fma(b0, (half8)a0.s1, acc0);
6620#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6621 acc1 = fma(b0, (half8)a1.s1, acc1);
6622#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6623#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6624 acc2 = fma(b0, (half8)a2.s1, acc2);
6625#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6626#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6627 acc3 = fma(b0, (half8)a3.s1, acc3);
6628#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6629
6630 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6631 src_addr.s1 += src1_stride_y;
6632 acc0 = fma(b0, (half8)a0.s2, acc0);
6633#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6634 acc1 = fma(b0, (half8)a1.s2, acc1);
6635#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6636#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6637 acc2 = fma(b0, (half8)a2.s2, acc2);
6638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6639#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6640 acc3 = fma(b0, (half8)a3.s2, acc3);
6641#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6642
6643 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6644 src_addr.s1 += src1_stride_y;
6645 acc0 = fma(b0, (half8)a0.s3, acc0);
6646#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6647 acc1 = fma(b0, (half8)a1.s3, acc1);
6648#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6649#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6650 acc2 = fma(b0, (half8)a2.s3, acc2);
6651#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6652#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6653 acc3 = fma(b0, (half8)a3.s3, acc3);
6654#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6655
6656 src_addr.s0 += 4 * sizeof(half);
6657 }
6658
6659 for(; i < (int)COLS_A; ++i)
6660 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006661#if defined(REINTERPRET_INPUT_AS_3D)
6662 // Load values from matrix A
6663 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
6664#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6665 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
6666#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6667#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6668 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
6669#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6670#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6671 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
6672#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6673#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006674 // Load values from matrix A
6675 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6676#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6677 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6678#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6679#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6680 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6681#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6682#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6683 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006685#endif // defined(REINTERPRET_INPUT_AS_3D)
6686
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006687 // Load values from matrix B
6688 half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
6689
6690 src_addr += (int2)(sizeof(half), src1_stride_y);
6691
6692 // Accumulate
6693 acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
6694#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6695 acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
6696#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6697#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6698 acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
6699#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6700#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6701 acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
6702#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6703 }
6704
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006705 int z = get_global_id(2);
6706
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006707 // Compute destination address
6708 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6709
6710 // Compute dst address
6711 __global uchar *dst_addr = offset(&dst, 0, 0);
6712
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006713 uint4 zout = 0;
6714
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006715#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006716
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006717 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006718 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006719 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006720 // | |
6721 // | plane0 |
6722 // | |
6723 // |__________________|
6724 // |******************|
6725 // | cross_plane_pad |
6726 // |******************|
6727 // | |
6728 // | plane1 |
6729 // | |
6730 // |__________________|
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006731
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006732 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006733 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6734 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006735
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006736 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006737 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006738
6739 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6740 // multiply dst_stride_z by DEPTH_GEMM3D
6741 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006742#else // defined(REINTERPRET_OUTPUT_AS_3D)
6743 // Add offset for batched GEMM
6744 dst_addr += z * dst_stride_z;
6745#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6746
6747 // Multiply by the weight of matrix-matrix product and store the result
6748#if defined(ALPHA)
6749 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
6750#endif // defined(ALPHA)
6751
6752 // Add beta*bias
6753#if defined(BETA)
6754 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
6755
6756#if defined(BROADCAST_BIAS)
6757 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
6758
6759 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6760
6761#ifndef UNIT_BETA
6762 SCALE_BLOCK(1, half, bias, BETA);
6763#endif // UNIT_BIAS
6764
6765 // acc = acc + bias[broadcasted]
6766 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
6767
6768#else // defined(BROADCAST_BIAS)
6769 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
6770 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
6771
6772 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6773
6774#ifndef UNIT_BETA
6775 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
6776#endif // UNIT_BIAS
6777
6778 // acc = acc + bias
6779 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
6780
6781#endif // defined(BROADCAST_BIAS)
6782#endif // defined(BETA)
6783
6784#if defined(ACTIVATION_TYPE)
6785 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
6786#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006787
6788 // Store the output block
Usama Arif0681e3b2019-04-25 14:28:07 +01006789 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006790}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006791#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006792
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006793#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006794
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006795#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * Each work-item updates 4 consecutive float values: dst = dst + BETA * src.
 * The destination is expected to already hold the (alpha-weighted) A x B product and the source the matrix C.
 *
 * @note The beta's value need to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the current content of the destination)
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
6836
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006837#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * Each work-item updates 8 consecutive half values: dst = dst + BETA * src.
 * The destination is expected to already hold the (alpha-weighted) A x B product and the source the matrix C.
 *
 * @note The beta's value need to be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, read and updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B (the current content of the destination)
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006878#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006879#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006880
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006881#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Each work-item produces 4 consecutive float values of one output row. Row idy of A is multiplied by the
 * 2D slice of B selected by the same index along the Z dimension (src1_stride_z * idy).
 *
 * @note The width of A need to be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First of the 4 output columns handled by this work-item, and the row of A being processed
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    // (byte offsets: .s0 walks row idy of A, .s1 walks the idy-th slice of B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    // Byte offset one past the last element of the current row of A
    int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume 2 elements of A (and 2 rows of B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Left-over loop: consume the last element of A when WIDTH_VECTOR_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006947#endif // defined(WIDTH_VECTOR_A)
6948
/** This kernel accumulates each row with the biases vector.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
6965#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve the addresses of the accumulate tensor and of the biases vector for this work-item
    Image  accum  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases = CONVERT_TO_VECTOR_STRUCT(biases);

    // Load VECTOR_SIZE accumulated values and the VECTOR_SIZE biases to add to them
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)accum.ptr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, (__global DATA_TYPE *)biases.ptr);

    // Element-wise accumulation of the biases
    accum_value = biases_value + accum_value;

    // Store result in the accumulate buffer
    VSTORE(VECTOR_SIZE)
    (accum_value, 0, (__global DATA_TYPE *)accum.ptr);
}
6983#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)