/*
 * Copyright (c) 2017-2020 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
Usama Arif0681e3b2019-04-25 14:28:07 +010024#include "gemm_helpers.h"
Vidhya Sudhan Loganathan17b0f8b2019-01-08 12:17:03 +000025#include "repeat.h"
Anthony Barbier6ff3b192017-09-04 18:44:23 +010026
#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
// Lane-index vectors (0, 1, ..., n-1) for every supported vector width n
#define INC2 (VEC_DATA_TYPE(uint, 2))(0, 1)
#define INC3 (VEC_DATA_TYPE(uint, 3))(0, 1, 2)
#define INC4 (VEC_DATA_TYPE(uint, 4))(0, 1, 2, 3)
#define INC8 (VEC_DATA_TYPE(uint, 8))(0, 1, 2, 3, 4, 5, 6, 7)
#define INC16 (VEC_DATA_TYPE(uint, 16))(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)
// Two-level expansion: INC(K0) first replaces K0 by its value, then CONCAT_INC pastes it onto "INC"
#define CONCAT_INC(K0) INC##K0
#define INC(K0) CONCAT_INC(K0)

#if(SRC_WIDTH % K0)
/** Mask the K0-wide vector @p a against the right edge of the source matrix.
 *
 * Only active when SRC_WIDTH is not a multiple of K0. The macro builds the per-lane column index
 * (x * K0 + lane), compares it with SRC_WIDTH and hands the converted comparison result to select()
 * as the mask, with 0 and @p a as the two candidate values. The result is written back to @p a.
 *
 * @param[in]      x Block index along X (scalar uint)
 * @param[in, out] a Vector of K0 elements of type DATA_TYPE (must be an lvalue)
 */
#define BOUNDARY_CONDITION_X(x, a) \
    ({ \
        (a) = select(0, (a), CONVERT((((x) * (VEC_DATA_TYPE(uint, K0))K0 + INC(K0)) < (VEC_DATA_TYPE(uint, K0))SRC_WIDTH), VEC_DATA_TYPE(DATA_TYPE, K0))); \
    })
#else // (SRC_WIDTH % K0)
// SRC_WIDTH is a multiple of K0: no partial block exists, so the macro expands to an empty statement
#define BOUNDARY_CONDITION_X(x, a) \
    ({})
#endif // (SRC_WIDTH % K0)
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
    // Block size (number of elements in one M0xK0 block)
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the blocks of two consecutive values of (y % V0)
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block.
    // The expansion is fully parenthesized so it is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the block along the width, y the block along the height, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    // Load values from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Apply the X boundary mask to every loaded row (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------
    // Create variables: uint zout0=0, zout1=0, ... (16 of them), passed to STORE_BLOCK
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(M0, K0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
/** Gather column @p i of the loaded rows a0..a(M0-1) into a vector of M0 elements and store it.
 *
 * The destination address is output_ptr + i * output_step_x * sizeof(DATA_TYPE) bytes. For M0 values
 * without a single matching vector width (5, 6, 7) the column is written as a 4-element store followed
 * by a second store (scalar, 2-element or 3-element) placed 4 elements further.
 *
 * @param[in] output_ptr    Byte pointer to the start of the output block
 * @param[in] output_step_x Distance between two consecutive stored columns (in elements)
 * @param[in] i             Column index as a single hexadecimal digit (0-9, A-F); it is pasted both
 *                          onto the vector component selector (.s##i) and onto the literal 0x##i
 */
#if M0 == 2
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 3 // M0 == 3
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 4 // M0 == 4
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#elif M0 == 5 // M0 == 5
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        DATA_TYPE res1 = a4.s##i; \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        *((__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4) = res1; \
    })
#elif M0 == 6 // M0 == 6
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 2) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 2))(a4.s##i, a5.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(2) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 7 // M0 == 7
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, 4) \
        res0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s##i, a1.s##i, a2.s##i, a3.s##i); \
        VEC_DATA_TYPE(DATA_TYPE, 3) \
        res1 = (VEC_DATA_TYPE(DATA_TYPE, 3))(a4.s##i, a5.s##i, a6.s##i); \
        VSTORE(4) \
        (res0, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
        VSTORE(3) \
        (res1, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE)) + 4); \
    })
#elif M0 == 8 // M0 == 8
#define TRANSPOSE_COLUMN_AND_STORE(output_ptr, output_step_x, i) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, M0) \
        res = (VEC_DATA_TYPE(DATA_TYPE, M0))(a0.s##i, a1.s##i, a2.s##i, a3.s##i, a4.s##i, a5.s##i, a6.s##i, a7.s##i); \
        VSTORE(M0) \
        (res, 0, (__global DATA_TYPE *)((output_ptr) + 0x##i * (output_step_x) * sizeof(DATA_TYPE))); \
    })
#else // M0 not supported
#error "M0 value not supported"
#endif // M0 conditions
/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The width of the input tensor must be passed at compile time using -DSRC_WIDTH (e.g. -DSRC_WIDTH=16)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (e.g. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,3,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (e.g. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                        ,
                                        uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                       )
{
    // Block size (number of elements in one M0xK0 block)
#define BLOCK_SIZE ((M0) * (K0))

    // Output offset X: distance (in elements) between the blocks of two consecutive values of (y % V0)
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (M0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive transposed columns of the same block.
    // The expansion is fully parenthesized so it is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((M0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (M0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the block along the width, y the block along the height, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Create variables: uint zin0=0, zin1=0, zin2=0...zin(M0-1)=0;
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zin, 0);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, cross_plane_pad, src_stride_y);

#else // defined(REINTERPRET_INPUT_AS_3D)

    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load values from the LHS matrix
    LOAD_BLOCK(M0, K0, DATA_TYPE, a, input_ptr, 0, src_stride_y, zin);

    // Apply the X boundary mask to every loaded row (no-op when SRC_WIDTH is a multiple of K0)
    BOUNDARY_CONDITION_X(x, a0);
#if M0 > 1
    BOUNDARY_CONDITION_X(x, a1);
#endif // M0 > 1
#if M0 > 2
    BOUNDARY_CONDITION_X(x, a2);
#endif // M0 > 2
#if M0 > 3
    BOUNDARY_CONDITION_X(x, a3);
#endif // M0 > 3
#if M0 > 4
    BOUNDARY_CONDITION_X(x, a4);
#endif // M0 > 4
#if M0 > 5
    BOUNDARY_CONDITION_X(x, a5);
#endif // M0 > 5
#if M0 > 6
    BOUNDARY_CONDITION_X(x, a6);
#endif // M0 > 6
#if M0 > 7
    BOUNDARY_CONDITION_X(x, a7);
#endif // M0 > 7

    // ---------------------------Transpose and store block -----------------------

    // One call per column of the block; the last argument is the column index as a hexadecimal digit
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 0);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 1);
#if K0 > 2
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 2);
#endif // K0 > 2
#if K0 > 3
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 3);
#endif // K0 > 3
#if K0 > 4
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 4);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 5);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 6);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 7);
#endif // K0 > 4
#if K0 > 8
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 8);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, 9);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, A);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, B);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, C);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, D);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, E);
    TRANSPOSE_COLUMN_AND_STORE(output_ptr, OUTPUT_STEP_X, F);
#endif // K0 > 8

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE) && defined(SRC_WIDTH)
#if defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 1,2,3,4,8,16
 *                                      H0: greater than 0
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst))
{
    // Block size (number of elements in one K0xN0 block)
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X: distance (in elements) between the blocks of two consecutive values of (x % H0)
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (N0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

    // Output step X: distance (in elements) between two consecutive rows of the same block.
    // The expansion is fully parenthesized so it is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((N0) * (H0))
#else // Do not interleave
#define OUTPUT_STEP_X (N0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates: x selects the block along the width, y the block along the height, z the batch
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % (uint)H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((
                                     x / (uint)H0)
                                 * (uint)dst_stride_y)
                                 + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Create variables: VEC_DATA_TYPE(DATA_TYPE, N0) a0=0, a1=0, a2=0...a(K0-1)=0;
    // Rows that fall outside the source matrix are never loaded below and therefore stay zero.
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0);

    // Load values from the RHS matrix. The first row is loaded unconditionally; every further row
    // is loaded only if its index (y * K0 + row) is below SRC_HEIGHT.
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
#if K0 > 1
    if(y * (uint)K0 + 1 < SRC_HEIGHT)
    {
        a1 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    }
#endif // K0 > 1
#if K0 > 2
    if(y * (uint)K0 + 2 < SRC_HEIGHT)
    {
        a2 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    }
#endif // K0 > 2
#if K0 > 3
    if(y * (uint)K0 + 3 < SRC_HEIGHT)
    {
        a3 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
    }
#endif // K0 > 3
#if K0 > 4
    if(y * (uint)K0 + 4 < SRC_HEIGHT)
    {
        a4 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y));
    }
    if(y * (uint)K0 + 5 < SRC_HEIGHT)
    {
        a5 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y));
    }
    if(y * (uint)K0 + 6 < SRC_HEIGHT)
    {
        a6 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y));
    }
    if(y * (uint)K0 + 7 < SRC_HEIGHT)
    {
        a7 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y));
    }
#endif // K0 > 4
#if K0 > 8
    if(y * (uint)K0 + 8 < SRC_HEIGHT)
    {
        a8 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 8 * src_stride_y));
    }
    if(y * (uint)K0 + 9 < SRC_HEIGHT)
    {
        a9 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 9 * src_stride_y));
    }
    if(y * (uint)K0 + 10 < SRC_HEIGHT)
    {
        aA = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 10 * src_stride_y));
    }
    if(y * (uint)K0 + 11 < SRC_HEIGHT)
    {
        aB = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 11 * src_stride_y));
    }
    if(y * (uint)K0 + 12 < SRC_HEIGHT)
    {
        aC = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 12 * src_stride_y));
    }
    if(y * (uint)K0 + 13 < SRC_HEIGHT)
    {
        aD = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 13 * src_stride_y));
    }
    if(y * (uint)K0 + 14 < SRC_HEIGHT)
    {
        aE = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 14 * src_stride_y));
    }
    if(y * (uint)K0 + 15 < SRC_HEIGHT)
    {
        aF = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 15 * src_stride_y));
    }
#endif // K0 > 8

    // ---------------------------Store output values ------------------------------
    // Create variables: uint zout0=0, zout1=0, ... (16 of them), passed to STORE_BLOCK
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(K0, N0, DATA_TYPE, a, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#if defined(TRANSPOSE)
/** This OpenCL kernel reshapes the rhs input matrix. The kernel splits the input matrix in blocks of size K0xN0 and stores each one (transposed) in
 *  the output matrix unrolling the values.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The height of the input tensor must be passed at compile time using -DSRC_HEIGHT (e.g. -DSRC_HEIGHT=16)
 * @note The block's dimensions (K0 and N0) must be passed at compile time using -DK0 and -DN0 (e.g. -DK0=2, -DN0=2).
 * @note The number of K0xN0 vertical blocks to store on the same output row must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 * @note The option -DTRANSPOSE must be passed at compile time.
 * @note Only the following values for K0, N0 and H0 are supported:
 *                                      N0: 2,3,4,8,16
 *                                      K0: 2,3,4,8,16
 *                                      H0: greater than 0
 * @note The first row of every block is always read. Any further row whose index (y * K0 + row) is not smaller than SRC_HEIGHT
 *       is not read and contributes zeros to the transposed block.
 *
 * @param[in]  src_ptr                           Pointer to the source RHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source RHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source RHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source RHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source RHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_reshape_rhs_matrix_t(TENSOR3D_DECLARATION(src),
                                        TENSOR3D_DECLARATION(dst))
{
    // Number of elements in one K0xN0 block
#define BLOCK_SIZE ((K0) * (N0))

    // Output offset X (distance, in elements, between two consecutive blocks on the same output row)
    // and output step X (distance, in elements, between two consecutive rows of one transposed block).
    // Both expansions are fully parenthesised so they can be used safely inside larger expressions.
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#define OUTPUT_STEP_X ((K0) * (H0))
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Loads row ROW of the block into a<ID>, but only when that row lies inside the source matrix.
    // Rows that are skipped keep the zero value they were initialised with.
#define RESHAPE_RHS_T_LOAD_ROW(ID, ROW)                                                   \
    do                                                                                    \
    {                                                                                     \
        if(y * (uint)K0 + ROW < SRC_HEIGHT)                                               \
        {                                                                                 \
            a##ID = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + ROW * src_stride_y)); \
        }                                                                                 \
    } while(0)

    // Builds one row of the transposed block: lane X of every loaded row a0 ... a(K0-1), gathered into a K0-wide vector
#if K0 == 2
#define RESHAPE_RHS_T_COLUMN(X) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##X, a1.s##X)
#elif K0 == 3 // K0 == 3
#define RESHAPE_RHS_T_COLUMN(X) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##X, a1.s##X, a2.s##X)
#elif K0 == 4 // K0 == 4
#define RESHAPE_RHS_T_COLUMN(X) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##X, a1.s##X, a2.s##X, a3.s##X)
#elif K0 == 8 // K0 == 8
#define RESHAPE_RHS_T_COLUMN(X) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##X, a1.s##X, a2.s##X, a3.s##X, a4.s##X, a5.s##X, a6.s##X, a7.s##X)
#elif K0 == 16 // K0 == 16
#define RESHAPE_RHS_T_COLUMN(X) (VEC_DATA_TYPE(DATA_TYPE, K0))(a0.s##X, a1.s##X, a2.s##X, a3.s##X, a4.s##X, a5.s##X, a6.s##X, a7.s##X, \
                                                               a8.s##X, a9.s##X, aA.s##X, aB.s##X, aC.s##X, aD.s##X, aE.s##X, aF.s##X)
#else // K0 not supported
#error "Not supported K0 value"
#endif // K0 conditions

    // Work-item coordinates
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: x selects the block column (N0 elements wide), y the block row (K0 rows tall), z the plane
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)N0 * sizeof(DATA_TYPE) + y * (uint)K0 * src_stride_y + z * (uint)src_stride_z;

    // Compute the output address: H0 consecutive x blocks share one output row (x / H0), (x % H0) selects the slot inside that row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (y * (uint)BLOCK_SIZE * (uint)H0 * sizeof(DATA_TYPE)) + ((x % H0) * (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE)) + ((x /
                                 (uint)H0) * (uint)dst_stride_y) + z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------
    REPEAT_VAR_INIT_TO_CONST(K0, VEC_DATA_TYPE(DATA_TYPE, N0), a, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    a0=0, a1=0, ... a(K0-1)=0;

    // Load values from the RHS matrix. The first row is read unconditionally, every other row is bounds-checked against SRC_HEIGHT
    a0 = VLOAD(N0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    RESHAPE_RHS_T_LOAD_ROW(1, 1);
#if K0 > 2
    RESHAPE_RHS_T_LOAD_ROW(2, 2);
#endif // K0 > 2
#if K0 > 3
    RESHAPE_RHS_T_LOAD_ROW(3, 3);
#endif // K0 > 3
#if K0 > 4
    RESHAPE_RHS_T_LOAD_ROW(4, 4);
    RESHAPE_RHS_T_LOAD_ROW(5, 5);
    RESHAPE_RHS_T_LOAD_ROW(6, 6);
    RESHAPE_RHS_T_LOAD_ROW(7, 7);
#endif // K0 > 4
#if K0 > 8
    RESHAPE_RHS_T_LOAD_ROW(8, 8);
    RESHAPE_RHS_T_LOAD_ROW(9, 9);
    RESHAPE_RHS_T_LOAD_ROW(A, 10);
    RESHAPE_RHS_T_LOAD_ROW(B, 11);
    RESHAPE_RHS_T_LOAD_ROW(C, 12);
    RESHAPE_RHS_T_LOAD_ROW(D, 13);
    RESHAPE_RHS_T_LOAD_ROW(E, 14);
    RESHAPE_RHS_T_LOAD_ROW(F, 15);
#endif // K0 > 8

    // ---------------------------Transpose the block ------------------------------
    REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), res, 0); //VEC_DATA_TYPE(DATA_TYPE, K0)    res0=0, res1=0, res2=0,... res(N0-1)=0;

    // This part computes the K0xN0 -> N0xK0 transposition for every supported pair, e.g.:
    // 2x2  -> 2x2,  3x4 -> 4x3,  4x8 -> 8x4,  8x16 -> 16x8,  16x16 -> 16x16
    res0 = RESHAPE_RHS_T_COLUMN(0);
    res1 = RESHAPE_RHS_T_COLUMN(1);
#if N0 > 2
    res2 = RESHAPE_RHS_T_COLUMN(2);
#endif // N0 > 2
#if N0 > 3
    res3 = RESHAPE_RHS_T_COLUMN(3);
#endif // N0 > 3
#if N0 > 4
    res4 = RESHAPE_RHS_T_COLUMN(4);
    res5 = RESHAPE_RHS_T_COLUMN(5);
    res6 = RESHAPE_RHS_T_COLUMN(6);
    res7 = RESHAPE_RHS_T_COLUMN(7);
#endif // N0 > 4
#if N0 > 8
    res8 = RESHAPE_RHS_T_COLUMN(8);
    res9 = RESHAPE_RHS_T_COLUMN(9);
    resA = RESHAPE_RHS_T_COLUMN(A);
    resB = RESHAPE_RHS_T_COLUMN(B);
    resC = RESHAPE_RHS_T_COLUMN(C);
    resD = RESHAPE_RHS_T_COLUMN(D);
    resE = RESHAPE_RHS_T_COLUMN(E);
    resF = RESHAPE_RHS_T_COLUMN(F);
#endif // N0 > 8

    // ---------------------------Store the output values ------------------------------
    // zout0 ... zoutF are all zero: no extra per-row offset is passed to STORE_BLOCK
    REPEAT_VAR_INIT_TO_CONST(16, uint, zout, 0);
    STORE_BLOCK(N0, K0, DATA_TYPE, res, output_ptr, OUTPUT_STEP_X * sizeof(DATA_TYPE), zout);

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
#undef RESHAPE_RHS_T_LOAD_ROW
#undef RESHAPE_RHS_T_COLUMN
}
#endif // defined(TRANSPOSE)
882#endif // defined(K0) && defined(N0) && defined(H0) && defined(DATA_TYPE) && defined(SRC_HEIGHT)
883
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +0000884#if defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +0000885
#define CONCAT(a, b) a##b

/** Fused multiply-add dot-product helpers.
 *
 * ARM_DOT<n>(a, b, c) accumulates into @p c the products of the first n lanes of @p a and @p b,
 * one fma() per lane, in increasing lane order (s0, s1, ...). ARM_DOT1 expects scalar operands;
 * ARM_DOT2/3/4 read the .s0 ... .s3 components; ARM_DOT8 and ARM_DOT16 split their operands
 * into .lo and .hi halves and delegate to the next smaller helper.
 *
 * @param[in]      a First operand (scalar for ARM_DOT1, otherwise a value exposing .sX or .lo/.hi members)
 * @param[in]      b Second operand, same shape as @p a
 * @param[in, out] c Accumulator. Must be a modifiable lvalue; it is evaluated more than once.
 *
 * Every argument is parenthesised before a member is selected from it, so expressions such as
 * a pointer dereference can be passed directly.
 */
#define ARM_DOT1(a, b, c)         \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
#define ARM_DOT2(a, b, c)               \
    ({                                  \
        (c) = fma((a).s0, (b).s0, (c)); \
        (c) = fma((a).s1, (b).s1, (c)); \
    })
#define ARM_DOT3(a, b, c)               \
    ({                                  \
        ARM_DOT2(a, b, c);              \
        (c) = fma((a).s2, (b).s2, (c)); \
    })
#define ARM_DOT4(a, b, c)               \
    ({                                  \
        ARM_DOT3(a, b, c);              \
        (c) = fma((a).s3, (b).s3, (c)); \
    })
#define ARM_DOT8(a, b, c)            \
    ({                               \
        ARM_DOT4((a).lo, (b).lo, c); \
        ARM_DOT4((a).hi, (b).hi, c); \
    })
#define ARM_DOT16(a, b, c)           \
    ({                               \
        ARM_DOT8((a).lo, (b).lo, c); \
        ARM_DOT8((a).hi, (b).hi, c); \
    })
917
918#if N0 == 2
919#define ARM_DOT_K0XN0(k0, a, b, c) \
920 ({ \
921 CONCAT(ARM_DOT, k0) \
922 ((a), (b##0), (c.s0)); \
923 CONCAT(ARM_DOT, k0) \
924 ((a), (b##1), (c.s1)); \
925 })
926#elif N0 == 3 // N0 == 3
927#define ARM_DOT_K0XN0(k0, a, b, c) \
928 ({ \
929 CONCAT(ARM_DOT, k0) \
930 ((a), (b##0), (c.s0)); \
931 CONCAT(ARM_DOT, k0) \
932 ((a), (b##1), (c.s1)); \
933 CONCAT(ARM_DOT, k0) \
934 ((a), (b##2), (c.s2)); \
935 })
936#elif N0 == 4 // N0 == 4
937#define ARM_DOT_K0XN0(k0, a, b, c) \
938 ({ \
939 CONCAT(ARM_DOT, k0) \
940 ((a), (b##0), (c.s0)); \
941 CONCAT(ARM_DOT, k0) \
942 ((a), (b##1), (c.s1)); \
943 CONCAT(ARM_DOT, k0) \
944 ((a), (b##2), (c.s2)); \
945 CONCAT(ARM_DOT, k0) \
946 ((a), (b##3), (c.s3)); \
947 })
948#elif N0 == 8 // N0 == 8
949#define ARM_DOT_K0XN0(k0, a, b, c) \
950 ({ \
951 CONCAT(ARM_DOT, k0) \
952 ((a), (b##0), (c.s0)); \
953 CONCAT(ARM_DOT, k0) \
954 ((a), (b##1), (c.s1)); \
955 CONCAT(ARM_DOT, k0) \
956 ((a), (b##2), (c.s2)); \
957 CONCAT(ARM_DOT, k0) \
958 ((a), (b##3), (c.s3)); \
959 CONCAT(ARM_DOT, k0) \
960 ((a), (b##4), (c.s4)); \
961 CONCAT(ARM_DOT, k0) \
962 ((a), (b##5), (c.s5)); \
963 CONCAT(ARM_DOT, k0) \
964 ((a), (b##6), (c.s6)); \
965 CONCAT(ARM_DOT, k0) \
966 ((a), (b##7), (c.s7)); \
967 })
968#elif N0 == 16 // N0 == 16
969#define ARM_DOT_K0XN0(k0, a, b, c) \
970 ({ \
971 CONCAT(ARM_DOT, k0) \
972 ((a), (b##0), (c.s0)); \
973 CONCAT(ARM_DOT, k0) \
974 ((a), (b##1), (c.s1)); \
975 CONCAT(ARM_DOT, k0) \
976 ((a), (b##2), (c.s2)); \
977 CONCAT(ARM_DOT, k0) \
978 ((a), (b##3), (c.s3)); \
979 CONCAT(ARM_DOT, k0) \
980 ((a), (b##4), (c.s4)); \
981 CONCAT(ARM_DOT, k0) \
982 ((a), (b##5), (c.s5)); \
983 CONCAT(ARM_DOT, k0) \
984 ((a), (b##6), (c.s6)); \
985 CONCAT(ARM_DOT, k0) \
986 ((a), (b##7), (c.s7)); \
987 CONCAT(ARM_DOT, k0) \
988 ((a), (b##8), (c.s8)); \
989 CONCAT(ARM_DOT, k0) \
990 ((a), (b##9), (c.s9)); \
991 CONCAT(ARM_DOT, k0) \
992 ((a), (b##A), (c.sA)); \
993 CONCAT(ARM_DOT, k0) \
994 ((a), (b##B), (c.sB)); \
995 CONCAT(ARM_DOT, k0) \
996 ((a), (b##C), (c.sC)); \
997 CONCAT(ARM_DOT, k0) \
998 ((a), (b##D), (c.sD)); \
999 CONCAT(ARM_DOT, k0) \
1000 ((a), (b##E), (c.sE)); \
1001 CONCAT(ARM_DOT, k0) \
1002 ((a), (b##F), (c.sF)); \
1003 })
1004#else // N0 not supported
1005#error "N0 value not supported"
1006#endif // N0 conditions
1007
1008/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1009 * The LHS matrix is NOT reshaped
1010 * The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
1011 *
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001012 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
1014 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
1015 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
1016 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
1017 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1019 * @note Only the following configurations of M0, N0 and K0 are currently supported:
1020 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
1021 * - N0 = 2, 3, 4, 8, 16
1022 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001023 * - H0 >= 1
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001024 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001025 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001026 * The activation function is performed after the bias addition
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001027 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1028 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1029 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1030 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1031 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1032 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
1033 *
Sheri Zhang1a378102020-04-30 12:59:39 +01001034 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
1035 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001037 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
Sheri Zhang1a378102020-04-30 12:59:39 +01001039 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001040 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
1041 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
1043 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
1045 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001046 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1047 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1048 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1049 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1050 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1051 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001052 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1053 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1054 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1055 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1056 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1057 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Sheri Zhang1a378102020-04-30 12:59:39 +01001058 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001059 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001060 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001061 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1062 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
1063 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001064 */
1065__kernel void gemm_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
1066 IMAGE_DECLARATION(rhs),
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001067#if defined(BETA)
1068 IMAGE_DECLARATION(bias),
1069#endif // defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001070 IMAGE_DECLARATION(dst),
1071 uint lhs_stride_z,
1072 uint rhs_stride_z,
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001073#if defined(BETA)
1074 uint bias_stride_z,
1075#endif //defined(BETA)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001076 uint dst_stride_z
1077#if defined(REINTERPRET_INPUT_AS_3D)
1078 ,
1079 uint lhs_cross_plane_pad
1080#endif // REINTERPRET_INPUT_AS_3D
1081#if defined(REINTERPRET_OUTPUT_AS_3D)
1082 ,
1083 uint dst_cross_plane_pad
1084#endif // REINTERPRET_OUTPUT_AS_3D
1085 )
1086{
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001087 // Block size
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001088#define RHS_BLOCK_SIZE ((K0) * (N0))
1089
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001090 // RHS offset and step X
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001091#if defined(RHS_INTERLEAVE)
1092#define RHS_OFFSET_X (K0)
1093#define RHS_STEP_X ((K0) * (H0))
1094#define RHS_STEP_LOOP (1)
1095#else // defined(RHS_INTERLEAVE)
1096#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1097#define RHS_STEP_X (K0)
1098#define RHS_STEP_LOOP (H0)
1099#endif // defined(RHS_INTERLEAVE)
1100
1101 uint x = get_global_id(0);
1102 uint y = get_global_id(1);
1103 uint z = get_global_id(2);
1104
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001105#if defined(DUMMY_WORK_ITEMS)
1106 if((x * N0 >= N) || (y * M0 >= M))
1107 {
1108 return;
1109 }
1110#endif // defined(DUMMY_WORK_ITEMS)
1111
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001112 // Compute LHS matrix address
1113 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
1114
Sheri Zhang1a378102020-04-30 12:59:39 +01001115 // Compute RHS reshaped matrix address
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001116 uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
1117
1118#if defined(MATRIX_B_DEPTH)
1119 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1120 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
1121#else // defined(MATRIX_B_DEPTH)
1122 rhs_offset += z * rhs_stride_z;
1123#endif // defined(MATRIX_B_DEPTH)
1124
Usama Arif0681e3b2019-04-25 14:28:07 +01001125 REPEAT_VAR_INIT_TO_CONST(8, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001126 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001127
1128#if defined(REINTERPRET_INPUT_AS_3D)
Usama Arif0681e3b2019-04-25 14:28:07 +01001129 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
1130 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001131
1132 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1133 // multiply lhs_stride_z by DEPTH_GEMM3D
1134 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
1135
1136#else // defined(REINTERPRET_INPUT_AS_3D)
1137
1138 // Add offset for batched GEMM
1139 lhs_offset += z * lhs_stride_z;
1140
1141#endif // defined(REINTERPRET_INPUT_AS_3D)
1142
1143 // Initialize the accumulators
1144 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
1145
1146 int i = 0;
1147 for(; i <= (K - K0); i += K0)
1148 {
1149 // Supported cases (M0, K0):
1150 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1151 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1152 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1153 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1154 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1155 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1156 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1157 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
1158 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001159 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001160
Sheri Zhang1a378102020-04-30 12:59:39 +01001161 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001162 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001163
1164 // Accumulate
1165 ARM_DOT_K0XN0(K0, a0, b, c0);
1166#if M0 > 1
1167 ARM_DOT_K0XN0(K0, a1, b, c1);
1168#endif // M0 > 1
1169#if M0 > 2
1170 ARM_DOT_K0XN0(K0, a2, b, c2);
1171#endif // M0 > 2
1172#if M0 > 3
1173 ARM_DOT_K0XN0(K0, a3, b, c3);
1174#endif // M0 > 3
1175#if M0 > 4
1176 ARM_DOT_K0XN0(K0, a4, b, c4);
1177#endif // M0 > 4
1178#if M0 > 5
1179 ARM_DOT_K0XN0(K0, a5, b, c5);
1180#endif // M0 > 5
1181#if M0 > 6
1182 ARM_DOT_K0XN0(K0, a6, b, c6);
1183#endif // M0 > 6
1184#if M0 > 7
1185 ARM_DOT_K0XN0(K0, a7, b, c7);
1186#endif // M0 > 7
1187
1188 lhs_offset += K0 * sizeof(DATA_TYPE);
1189 rhs_offset += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
1190 }
1191
1192 // Left-over accumulations
1193 for(; i < K; ++i)
1194 {
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001195 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001196 LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001197
Sheri Zhang1a378102020-04-30 12:59:39 +01001198 // Load values from RHS reshaped matrix
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001199 LOAD_BLOCK(N0, 1, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001200
1201 // Accumulate
1202 ARM_DOT_K0XN0(1, a0, b, c0);
1203#if M0 > 1
1204 ARM_DOT_K0XN0(1, a1, b, c1);
1205#endif // M0 > 1
1206#if M0 > 2
1207 ARM_DOT_K0XN0(1, a2, b, c2);
1208#endif // M0 > 2
1209#if M0 > 3
1210 ARM_DOT_K0XN0(1, a3, b, c3);
1211#endif // M0 > 3
1212#if M0 > 4
1213 ARM_DOT_K0XN0(1, a4, b, c4);
1214#endif // M0 > 4
1215#if M0 > 5
1216 ARM_DOT_K0XN0(1, a5, b, c5);
1217#endif // M0 > 5
1218#if M0 > 6
1219 ARM_DOT_K0XN0(1, a6, b, c6);
1220#endif // M0 > 6
1221#if M0 > 7
1222 ARM_DOT_K0XN0(1, a7, b, c7);
1223#endif // M0 > 7
1224
1225 lhs_offset += sizeof(DATA_TYPE);
1226 rhs_offset += sizeof(DATA_TYPE);
1227 }
1228
1229 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
1230
1231 REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;
1232
1233#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001234
1235 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01001236 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001237
1238 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1239 // multiply dst_stride_z by DEPTH_GEMM3D
1240 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1241
1242#else // defined(REINTERPRET_OUTPUT_AS_3D)
1243
1244 // Add offset for batched GEMM
1245 dst_addr += z * dst_stride_z;
1246
1247#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1248
1249 // Multiply by the weight of matrix-matrix product and store the result
1250#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01001251 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001252#endif // defined(ALPHA)
1253
Georgios Pinitasb0f342e2019-05-21 13:32:43 +01001254 // Add beta*bias
1255#if defined(BETA)
1256#if defined(BROADCAST_BIAS)
1257 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
1258
1259 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1260
1261#ifndef UNIT_BETA
1262 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
1263#endif // UNIT_BIAS
1264
1265 // c = c + bias[broadcasted]
1266 ADD_BLOCK_BROADCAST(M0, c, bias0);
1267
1268#else // defined(BROADCAST_BIAS)
1269 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
1270 2) * bias_stride_z;
1271
1272 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
1273
1274#ifndef UNIT_BETA
1275 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
1276#endif // UNIT_BIAS
1277
1278 // c = c + bias
1279 ADD_BLOCK(M0, c, bias);
1280
1281#endif // defined(BROADCAST_BIAS)
1282#endif // defined(BETA)
1283
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001284#if defined(ACTIVATION_TYPE)
1285 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
1286#endif // defined(ACTIVATION_TYPE)
1287
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001288 // Store output block
Usama Arif0681e3b2019-04-25 14:28:07 +01001289 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001290
1291#undef RHS_BLOCK_SIZE
1292#undef RHS_OFFSET_X
1293#undef RHS_STEP_X
1294}
Gian Marco Iodiceba5e0962019-03-11 12:17:44 +00001295
// In-place fused multiply-add: acc = fma(x, y, acc).
// acc is both read and assigned, so it must be an lvalue.
#define VFMA(x, y, acc)       \
    ({                        \
        acc = fma(x, y, acc); \
    })
1300
// LD_RHS_VFMA_M0xN0(k, lhs, acc)
//
// Loads one N0-wide vector from the RHS buffer at
//     rhs_ptr + rhs_offset + 0x<k> * RHS_STEP_X * sizeof(DATA_TYPE)
// and, for every row r in [0, M0), accumulates
//     acc<r> = fma((N0-wide splat of lhs<r>.s<k>), loaded vector, acc<r>)
// through VFMA.
//
// k   : a single hexadecimal digit (0-9, A-F). It is token-pasted both into the
//       literal 0x<k> and into the vector accessor .s<k>, so it must be given
//       as a bare token.
// lhs : basename of the LHS row variables (lhs0, lhs1, ...).
// acc : basename of the accumulator variables (acc0, acc1, ...).
//
// The macro reads rhs_ptr, rhs_offset and RHS_STEP_X from the scope in which it
// is expanded. One variant is provided for each supported M0 (1 to 8); any other
// value stops the build.
#if M0 == 1
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
    })
#elif M0 == 2 // M0 == 2
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
    })
#elif M0 == 3 // M0 == 3
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
    })
#elif M0 == 4 // M0 == 4
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##3).s##k), rhs_row, (acc##3)); \
    })
#elif M0 == 5 // M0 == 5
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##3).s##k), rhs_row, (acc##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##4).s##k), rhs_row, (acc##4)); \
    })
#elif M0 == 6 // M0 == 6
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##3).s##k), rhs_row, (acc##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##4).s##k), rhs_row, (acc##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##5).s##k), rhs_row, (acc##5)); \
    })
#elif M0 == 7 // M0 == 7
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##3).s##k), rhs_row, (acc##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##4).s##k), rhs_row, (acc##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##5).s##k), rhs_row, (acc##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##6).s##k), rhs_row, (acc##6)); \
    })
#elif M0 == 8 // M0 == 8
#define LD_RHS_VFMA_M0xN0(k, lhs, acc) \
    ({ \
        VEC_DATA_TYPE(DATA_TYPE, N0) \
        rhs_row = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0x##k * RHS_STEP_X * sizeof(DATA_TYPE))); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##0).s##k), rhs_row, (acc##0)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##1).s##k), rhs_row, (acc##1)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##2).s##k), rhs_row, (acc##2)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##3).s##k), rhs_row, (acc##3)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##4).s##k), rhs_row, (acc##4)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##5).s##k), rhs_row, (acc##5)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##6).s##k), rhs_row, (acc##6)); \
        VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((lhs##7).s##k), rhs_row, (acc##7)); \
    })
#else // M0 not supported
#error "M0 not supported"
#endif // M0 not supported
1388
/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
 *  The LHS matrix is NOT reshaped
 *  The RHS is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is NOT transposed
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M, N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90).
 * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (e.g. -DN0=8, -DK0=4).
 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 1, 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 2, 3, 4, 8, 16
 *  - K0 = 2, 3, 4, 8, 16
 *  - H0 >= 1
 *
 * @note If -DALPHA is passed at compile time, the accumulated block is scaled by ALPHA before the bias is added.
 * @note If -DBETA is passed at compile time, the bias arguments become part of the kernel signature and the bias block is added to the result.
 *       The bias is scaled by BETA unless -DUNIT_BETA is also passed. With -DBROADCAST_BIAS a single bias row of N0 elements is loaded and broadcast over the M0 rows of the block.
 * @note If -DMATRIX_B_DEPTH is passed at compile time, the RHS batch offset is computed from (z % MATRIX_B_DEPTH) instead of z.
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  lhs_stride_z                       Stride of the LHS matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  lhs_cross_plane_pad                (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_only_rhs_nt(IMAGE_DECLARATION(lhs),
                                           IMAGE_DECLARATION(rhs),
#if defined(BETA)
                                           IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                           IMAGE_DECLARATION(dst),
                                           uint lhs_stride_z,
                                           uint rhs_stride_z,
#if defined(BETA)
                                           uint bias_stride_z,
#endif //defined(BETA)
                                           uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                           ,
                                           uint lhs_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                           ,
                                           uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                          )
{
    // Block size
#define RHS_BLOCK_SIZE ((K0) * (N0))

    // RHS offset and step X
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (N0)
#define RHS_STEP_X ((N0) * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X (N0)
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose block starts outside the N x M output have nothing to do
    if((x * N0 >= N) || (y * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;

    // Compute RHS reshaped matrix address
    uint rhs_offset = rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += z * rhs_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    REPEAT_VAR_INIT_TO_CONST(8, uint, zin, 0);   //uint zin0=0,zin1=0,zin2=0,... zin7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0); //uint zero0=0,zero1=0,zero2=0,... (16 variables in total)

#if defined(REINTERPRET_INPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zin, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply lhs_stride_z by DEPTH_GEMM3D
    lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    lhs_offset += z * lhs_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0)    c0=0,c1=0,c2=0,... c(M0-1)=0;

    int i = 0;
    for(; i <= (K - K0); i += K0)
    {
        // Supported cases (M0, K0):
        // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
        // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
        // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
        // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
        // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
        // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
        // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
        // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zin);

        // One RHS row load + M0 fused multiply-adds per K0 element.
        // The row index is a hexadecimal digit because the macro pastes it into .s<digit>
        LD_RHS_VFMA_M0xN0(0, a, c);
        LD_RHS_VFMA_M0xN0(1, a, c);
#if K0 > 2
        LD_RHS_VFMA_M0xN0(2, a, c);
#endif // K0 > 2
#if K0 > 3
        LD_RHS_VFMA_M0xN0(3, a, c);
#endif // K0 > 3
#if K0 > 4
        LD_RHS_VFMA_M0xN0(4, a, c);
        LD_RHS_VFMA_M0xN0(5, a, c);
        LD_RHS_VFMA_M0xN0(6, a, c);
        LD_RHS_VFMA_M0xN0(7, a, c);
#endif // K0 > 4
#if K0 > 8
        LD_RHS_VFMA_M0xN0(8, a, c);
        LD_RHS_VFMA_M0xN0(9, a, c);
        LD_RHS_VFMA_M0xN0(A, a, c);
        LD_RHS_VFMA_M0xN0(B, a, c);
        LD_RHS_VFMA_M0xN0(C, a, c);
        LD_RHS_VFMA_M0xN0(D, a, c);
        LD_RHS_VFMA_M0xN0(E, a, c);
        LD_RHS_VFMA_M0xN0(F, a, c);
#endif // K0 > 8

        lhs_offset += K0 * sizeof(DATA_TYPE);
        rhs_offset += K0 * RHS_STEP_X * RHS_STEP_LOOP * sizeof(DATA_TYPE);
    }

    // Left-over accumulations
    for(; i < K; ++i)
    {
        // Load values from LHS matrix.
        // Each row holds a single DATA_TYPE element, but it is declared as a 2-element vector
        // because LD_RHS_VFMA_M0xN0 reads it through the .s0 accessor
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zin0));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zin7));
#endif // M0 > 7

        LD_RHS_VFMA_M0xN0(0, a, c);

        lhs_offset += sizeof(DATA_TYPE);
        rhs_offset += RHS_STEP_X * sizeof(DATA_TYPE);
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(8, uint, zout, 0); //uint zout0=0,zout1=0,zout2=0,... zout7=0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, c, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);

    // Release every helper macro defined at the top of this kernel so none of them leaks past it
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef RHS_STEP_LOOP
}
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001677#endif // defined(M0) && defined(N0) && defined(K0) && defined(H0) && defined(DATA_TYPE) && defined(M) && defined(N) && defined(K)
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001678
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001679#if defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(DATA_TYPE) && defined(DATA_TYPE_ACCUMULATOR) && defined(M) && defined(N)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001680
// Single accumulate step shared by every ARM_DOT_K0 variant below.
//  - With -DMIXED_PRECISION the product is added to the accumulator with a plain
//    multiply followed by +=.
//  - Otherwise the step is a fused multiply-add through fma().
#if defined(MIXED_PRECISION)
#define ARM_DOT_K0_STEP(lhs_val, rhs_val, acc) acc += lhs_val * rhs_val
#else // defined(MIXED_PRECISION)
#define ARM_DOT_K0_STEP(lhs_val, rhs_val, acc) acc = fma(lhs_val, rhs_val, acc)
#endif // defined(MIXED_PRECISION)

// ARM_DOT_K0(a, b, c): accumulates into c the dot product of the first K0
// components (.s0, .s1, ...) of the vectors a and b, one component at a time and
// in ascending order. Only K0 = 2, 3, 4, 8 and 16 are provided; any other value
// stops the build.
#if K0 == 2
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
    })
#elif K0 == 3 // K0 == 3
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
    })
#elif K0 == 4 // K0 == 4
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
    })
#elif K0 == 8 // K0 == 8
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
        ARM_DOT_K0_STEP(a.s4, b.s4, c); \
        ARM_DOT_K0_STEP(a.s5, b.s5, c); \
        ARM_DOT_K0_STEP(a.s6, b.s6, c); \
        ARM_DOT_K0_STEP(a.s7, b.s7, c); \
    })
#elif K0 == 16 // K0 == 16
#define ARM_DOT_K0(a, b, c)             \
    ({                                  \
        ARM_DOT_K0_STEP(a.s0, b.s0, c); \
        ARM_DOT_K0_STEP(a.s1, b.s1, c); \
        ARM_DOT_K0_STEP(a.s2, b.s2, c); \
        ARM_DOT_K0_STEP(a.s3, b.s3, c); \
        ARM_DOT_K0_STEP(a.s4, b.s4, c); \
        ARM_DOT_K0_STEP(a.s5, b.s5, c); \
        ARM_DOT_K0_STEP(a.s6, b.s6, c); \
        ARM_DOT_K0_STEP(a.s7, b.s7, c); \
        ARM_DOT_K0_STEP(a.s8, b.s8, c); \
        ARM_DOT_K0_STEP(a.s9, b.s9, c); \
        ARM_DOT_K0_STEP(a.sA, b.sA, c); \
        ARM_DOT_K0_STEP(a.sB, b.sB, c); \
        ARM_DOT_K0_STEP(a.sC, b.sC, c); \
        ARM_DOT_K0_STEP(a.sD, b.sD, c); \
        ARM_DOT_K0_STEP(a.sE, b.sE, c); \
        ARM_DOT_K0_STEP(a.sF, b.sF, c); \
    })
#else // K0 not supported
#error "K0 value not supported"
#endif // K0 conditions
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001796
// ARM_DOT_K0XN0(a, b, c)
//
// For every n in [0, N0), accumulates into component .s<n> of the vector c the
// K0-element dot product ARM_DOT_K0(a, b<n>, c.s<n>).
//
// a : LHS vector, shared by all N0 dot products.
// b : basename of the N0 RHS vectors (b0, b1, ..., with hexadecimal suffixes
//     A-F for indices 10 to 15).
// c : accumulator vector with at least N0 components.
//
// Only N0 = 2, 3, 4, 8 and 16 are provided; any other value stops the build.
//
// Kernels earlier in this file invoke a macro of the same name with four
// arguments (k0, a, b, c). When the guards of both sections are satisfied in
// the same build, defining the 3-argument form on top of that one is an
// incompatible macro redefinition, so the name is released first. #undef on a
// name that is not defined is a no-op.
#undef ARM_DOT_K0XN0

#if N0 == 2
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
    })
#elif N0 == 3 // N0 == 3
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
    })
#elif N0 == 4 // N0 == 4
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
    })
#elif N0 == 8 // N0 == 8
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
    })
#elif N0 == 16 // N0 == 16
#define ARM_DOT_K0XN0(a, b, c)           \
    ({                                   \
        ARM_DOT_K0((a), (b##0), (c.s0)); \
        ARM_DOT_K0((a), (b##1), (c.s1)); \
        ARM_DOT_K0((a), (b##2), (c.s2)); \
        ARM_DOT_K0((a), (b##3), (c.s3)); \
        ARM_DOT_K0((a), (b##4), (c.s4)); \
        ARM_DOT_K0((a), (b##5), (c.s5)); \
        ARM_DOT_K0((a), (b##6), (c.s6)); \
        ARM_DOT_K0((a), (b##7), (c.s7)); \
        ARM_DOT_K0((a), (b##8), (c.s8)); \
        ARM_DOT_K0((a), (b##9), (c.s9)); \
        ARM_DOT_K0((a), (b##A), (c.sA)); \
        ARM_DOT_K0((a), (b##B), (c.sB)); \
        ARM_DOT_K0((a), (b##C), (c.sC)); \
        ARM_DOT_K0((a), (b##D), (c.sD)); \
        ARM_DOT_K0((a), (b##E), (c.sE)); \
        ARM_DOT_K0((a), (b##F), (c.sF)); \
    })
#else // N0 not supported
#error "N0 value not supported"
#endif // N0 conditions
1853
1854/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
1855 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
1856 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
1857 *
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001858 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1859 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
1860 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001861 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01001862 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001863 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
1864 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
1865 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
1868 * @note Only the following configurations of M0, N0 and K0 are currently supported:
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01001869 * - M0 = 2, 3, 4, 5, 6, 7, 8
Gian Marco Iodicebacfec52019-01-11 11:30:55 +00001870 * - N0 = 2, 3, 4, 8, 16
1871 * - K0 = 2, 3, 4, 8, 16
Gian Marco Iodice62251f72019-03-11 16:07:12 +00001872 * - V0 >= 1
1873 * - H0 >= 1
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001874 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001875 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01001876 * The activation function is performed after the bias addition
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01001877 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001878 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1879 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1880 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1881 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
1882 *
 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F16/F32
 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
 * @param[in]  rhs_ptr                            Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  rhs_stride_x                       Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in]  rhs_step_x                         rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  rhs_stride_y                       Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  rhs_step_y                         rhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  rhs_offset_first_element_in_bytes  The offset of the first element in the RHS reshaped matrix
1895 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1896 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1897 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
1898 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1899 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
1900 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1901 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
1902 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1903 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1904 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1905 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1906 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01001907 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001908 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
1909 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
1910 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1911 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1912 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001913 */
1914__kernel void gemm_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
1915 IMAGE_DECLARATION(rhs),
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001916#if defined(BETA)
1917 IMAGE_DECLARATION(bias),
1918#endif // defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001919 IMAGE_DECLARATION(dst),
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01001920 uint k,
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001921 uint lhs_stride_z,
1922 uint rhs_stride_z,
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001923#if defined(BETA)
1924 uint bias_stride_z,
1925#endif //defined(BETA)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001926 uint dst_stride_z
1927#if defined(REINTERPRET_OUTPUT_AS_3D)
1928 ,
1929 uint dst_cross_plane_pad
1930#endif // REINTERPRET_OUTPUT_AS_3D
1931 )
1932{
1933 // Block size
1934#define LHS_BLOCK_SIZE ((K0) * (M0))
1935
1936#if defined(LHS_INTERLEAVE)
1937#define LHS_OFFSET_X (K0)
1938#define LHS_STEP_X ((K0) * (V0))
1939#define LHS_STEP_LOOP (1)
1940#else // defined(INTERLEAVE)
1941#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
1942#define LHS_STEP_X (K0)
1943#define LHS_STEP_LOOP (V0)
1944#endif // defined(INTERLEAVE)
1945
1946 // Block size
1947#define RHS_BLOCK_SIZE ((K0) * (N0))
1948
1949 // RHS offset and step X
1950#if defined(RHS_INTERLEAVE)
1951#define RHS_OFFSET_X (K0)
1952#define RHS_STEP_X ((K0) * (H0))
1953#define RHS_STEP_LOOP (1)
1954#else // defined(RHS_INTERLEAVE)
1955#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
1956#define RHS_STEP_X (K0)
1957#define RHS_STEP_LOOP (H0)
1958#endif // defined(RHS_INTERLEAVE)
1959
Gian Marco Iodiceb0c50372019-03-15 10:13:05 +00001960#if defined(DUMMY_WORK_ITEMS)
1961 if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
1962 {
1963 return;
1964 }
1965#endif // defined(DUMMY_WORK_ITEMS)
1966
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001967 // Compute LHS matrix address
1968 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
1969 (get_global_id(2) * lhs_stride_z);
1970
1971 // Compute RHS matrix address
1972 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (get_global_id(0) % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(0) / (uint)H0) * rhs_stride_y;
1973
1974#if defined(MATRIX_B_DEPTH)
1975 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1976 rhs_addr += (get_global_id(2) % MATRIX_B_DEPTH) * rhs_stride_z;
1977#else // defined(MATRIX_B_DEPTH)
1978 rhs_addr += get_global_id(2) * rhs_stride_z;
1979#endif // defined(MATRIX_B_DEPTH)
1980
1981 // Initialize the accumulators
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01001982 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001983
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01001984 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
1985 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
Usama Arif0681e3b2019-04-25 14:28:07 +01001986
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01001987 for(int i = 0; i < k; i += K0)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001988 {
1989 // Supported cases (M0, K0):
Gian Marco Iodiceadc53952019-02-15 11:10:31 +00001990 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
1991 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
1992 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
1993 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
1994 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
1995 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
1996 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
1997 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00001998 // Load values from LHS matrix
Usama Arif0681e3b2019-04-25 14:28:07 +01001999 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002000
2001 // Load values from RHS matrix
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002002 LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X * sizeof(DATA_TYPE), zero);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002003
2004 // Accumulate
2005 ARM_DOT_K0XN0(a0, b, c0);
2006#if M0 > 1
2007 ARM_DOT_K0XN0(a1, b, c1);
2008#endif // M0 > 1
2009#if M0 > 2
2010 ARM_DOT_K0XN0(a2, b, c2);
2011#endif // M0 > 2
2012#if M0 > 3
2013 ARM_DOT_K0XN0(a3, b, c3);
2014#endif // M0 > 3
2015#if M0 > 4
2016 ARM_DOT_K0XN0(a4, b, c4);
2017#endif // M0 > 4
2018#if M0 > 5
2019 ARM_DOT_K0XN0(a5, b, c5);
2020#endif // M0 > 5
2021#if M0 > 6
2022 ARM_DOT_K0XN0(a6, b, c6);
2023#endif // M0 > 6
2024#if M0 > 7
2025 ARM_DOT_K0XN0(a7, b, c7);
2026#endif // M0 > 7
2027
2028 lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);
2029 rhs_addr += (N0 * RHS_STEP_X * RHS_STEP_LOOP) * sizeof(DATA_TYPE);
2030 }
2031
2032 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);
2033
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002034 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002035
2036#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002037
2038 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
Usama Arif0681e3b2019-04-25 14:28:07 +01002039 CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002040 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2041 // multiply dst_stride_z by DEPTH_GEMM3D
2042 dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;
2043
2044#else // defined(REINTERPRET_OUTPUT_AS_3D)
2045
2046 // Add offset for batched GEMM
2047 dst_addr += get_global_id(2) * dst_stride_z;
2048
2049#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2050
2051 // Multiply by the weight of matrix-matrix product and store the result
2052#if defined(ALPHA)
Usama Arif0681e3b2019-04-25 14:28:07 +01002053 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002054#endif // defined(ALPHA)
2055
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002056 // Add beta*bias
2057#if defined(BETA)
2058#if defined(BROADCAST_BIAS)
2059 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
2060
2061 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2062
2063#ifndef UNIT_BETA
2064 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2065#endif // UNIT_BIAS
2066
2067 // c = c + bias[broadcasted]
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002068#if defined(MIXED_PRECISION)
2069 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2070 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2071#else // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002072 ADD_BLOCK_BROADCAST(M0, c, bias0);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002073#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002074
2075#else // defined(BROADCAST_BIAS)
2076 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
2077 2) * bias_stride_z;
2078
2079 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2080
2081#ifndef UNIT_BETA
2082 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2083#endif // UNIT_BIAS
2084
2085 // c = c + bias
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002086#if defined(MIXED_PRECISION)
2087 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2088 ADD_BLOCK(M0, c, bias_hp);
2089#else // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002090 ADD_BLOCK(M0, c, bias);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002091#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002092
2093#endif // defined(BROADCAST_BIAS)
2094#endif // defined(BETA)
2095
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002096#if defined(ACTIVATION_TYPE)
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002097#if defined(MIXED_PRECISION)
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002098 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002099#else // defined(MIXED_PRECISION)
2100 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
2101#endif // defined(MIXED_PRECISION)
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01002102#endif // defined(ACTIVATION_TYPE)
2103
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002104 // Store output block
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002105#if defined(MIXED_PRECISION)
2106 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2107#else // defined(MIXED_PRECISION)
Usama Arif0681e3b2019-04-25 14:28:07 +01002108 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002109#endif // defined(MIXED_PRECISION)
Gian Marco Iodicee16c8902019-06-14 16:11:10 +01002110
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002111#undef LHS_BLOCK_SIZE
2112#undef LHS_OFFSET_X
2113#undef LHS_STEP_X
2114#undef RHS_BLOCK_SIZE
2115#undef RHS_OFFSET_X
2116#undef RHS_STEP_X
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002117#undef LHS_STEP_LOOP
2118#undef RHS_STEP_LOOP
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00002119}
giuros01b3204e72019-04-01 13:50:22 +01002120
#if defined(OPENCL_IMAGE_SUPPORT)
/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
 *  The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
 *  The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
 *
 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The data type used for the accumulators must be passed at compile time using -DDATA_TYPE_ACCUMULATOR (e.g. -DDATA_TYPE_ACCUMULATOR=float)
 * @note The F16 computation also supports mixed precision through the option -DMIXED_PRECISION passed at compile time. If enabled, DATA_TYPE_ACCUMULATOR should be set to float
 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (e.g. -DM=52 and -DN=90).
 *       The length of the accumulation (K dimension) is read from the runtime argument @p k, as in gemm_mm_reshaped_lhs_nt_rhs_t; a -DK define is not used for the loop bound.
 * @note -DRHS_HEIGHT must be passed at compile time: it is multiplied by the batch index to find the image row at which a batch of the RHS starts.
 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
 * @note Only the following configurations of M0, N0 and K0 are currently supported:
 *  - M0 = 2, 3, 4, 5, 6, 7, 8
 *  - N0 = 4, 8, 16
 *  - K0 = 4, 8, 16
 *  - V0 >= 1
 *  - H0 >= 1
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
 *
 * @param[in]  lhs_ptr                            Pointer to the LHS reshaped matrix. Supported data type: F32
 * @param[in]  lhs_stride_x                       Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in]  lhs_step_x                         lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  lhs_stride_y                       Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in]  lhs_step_y                         lhs_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  lhs_offset_first_element_in_bytes  The offset of the first element in the LHS reshaped matrix
 * @param[in]  rhs_img                            The RHS reshaped matrix as OpenCL image object. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  bias_step_y                        (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data type: same as @p lhs_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  k                                  Number of columns in LHS matrix and rows in RHS matrix not reshaped. It bounds the accumulation loop.
 * @param[in]  lhs_stride_z                       Stride of the LHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  rhs_stride_z                       Stride of the RHS reshaped matrix in Z dimension (in bytes)
 * @param[in]  bias_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_reshaped_lhs_nt_rhs_t_texture(IMAGE_DECLARATION(lhs),
                                                    __read_only image2d_t rhs_img,
#if defined(BETA)
                                                    IMAGE_DECLARATION(bias),
#endif // defined(BETA)
                                                    IMAGE_DECLARATION(dst),
                                                    uint k,
                                                    uint lhs_stride_z,
                                                    uint rhs_stride_z,
#if defined(BETA)
                                                    uint bias_stride_z,
#endif // defined(BETA)
                                                    uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                    ,
                                                    uint dst_cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                   )
{
    // Pixel unit: number of image pixels that hold one K0-wide vector
#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(K0)

    // Number of elements in one M0xK0 block of the reshaped LHS matrix
#define LHS_BLOCK_SIZE ((K0) * (M0))

    // LHS offset and step X (in elements)
#if defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (K0)
#define LHS_STEP_X ((K0) * (V0))
#define LHS_STEP_LOOP (1)
#else // defined(LHS_INTERLEAVE)
#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
#define LHS_STEP_X (K0)
#define LHS_STEP_LOOP (V0)
#endif // defined(LHS_INTERLEAVE)

    // Size of one RHS block, in pixels
#define RHS_BLOCK_SIZE (PIXEL_UNIT * (N0))

    // RHS offset and step X (in pixels, because the RHS is read through the image)
#if defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (PIXEL_UNIT)
#define RHS_STEP_X (PIXEL_UNIT * (H0))
#define RHS_STEP_LOOP (1)
#else // defined(RHS_INTERLEAVE)
#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
#define RHS_STEP_X PIXEL_UNIT
#define RHS_STEP_LOOP (H0)
#endif // defined(RHS_INTERLEAVE)

#if defined(DUMMY_WORK_ITEMS)
    // Work-items whose block starts outside the MxN output return immediately
    if((get_global_id(0) * N0 >= N) || (get_global_id(1) * M0 >= M))
    {
        return;
    }
#endif // defined(DUMMY_WORK_ITEMS)

    // Compute LHS matrix address
    __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (get_global_id(1) % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (get_global_id(1) / V0) * (uint)lhs_stride_y +
                               (get_global_id(2) * lhs_stride_z);

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    const uint z_rhs = (get_global_id(2) % MATRIX_B_DEPTH);
#else  // defined(MATRIX_B_DEPTH)
    const uint z_rhs = get_global_id(2);
#endif // defined(MATRIX_B_DEPTH)

    // Compute RHS matrix coordinates: x is in pixels, y is the image row (batches are RHS_HEIGHT rows apart)
    uint       x_rhs = (get_global_id(0) % H0) * (uint)RHS_OFFSET_X;
    const uint y_rhs = (get_global_id(0) / (uint)H0) + z_rhs * RHS_HEIGHT;

    // Initialize the accumulators
    REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0); //uint zlhs0=0,zlhs1=0,zlhs2=0,... zlhs7=0;
    REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);

    // The loop is bounded by the runtime argument k, exactly like the buffer-based kernel.
    // It used to read the compile-time K, which left the k argument unused.
    for(int i = 0; i < k; i += K0)
    {
        // Load values from LHS matrix
        LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X * sizeof(DATA_TYPE), zlhs);

        // Load values from RHS matrix stored in a cl_image
        REPEAT_VAR_INIT_TO_CONST(N0, VEC_DATA_TYPE(DATA_TYPE, K0), b, 0);
        LOAD_TEXTURE2D(N0, PIXEL_UNIT, DATA_TYPE, b, rhs_img, x_rhs, y_rhs, RHS_STEP_X, 0);

        // Accumulate
        ARM_DOT_K0XN0(a0, b, c0);
#if M0 > 1
        ARM_DOT_K0XN0(a1, b, c1);
#endif // M0 > 1
#if M0 > 2
        ARM_DOT_K0XN0(a2, b, c2);
#endif // M0 > 2
#if M0 > 3
        ARM_DOT_K0XN0(a3, b, c3);
#endif // M0 > 3
#if M0 > 4
        ARM_DOT_K0XN0(a4, b, c4);
#endif // M0 > 4
#if M0 > 5
        ARM_DOT_K0XN0(a5, b, c5);
#endif // M0 > 5
#if M0 > 6
        ARM_DOT_K0XN0(a6, b, c6);
#endif // M0 > 6
#if M0 > 7
        ARM_DOT_K0XN0(a7, b, c7);
#endif // M0 > 7

        // Advance to the next block along K: the LHS by bytes, the RHS by pixels
        lhs_addr += (M0 * LHS_STEP_X * LHS_STEP_LOOP) * sizeof(DATA_TYPE);

        x_rhs += N0 * RHS_STEP_X * RHS_STEP_LOOP;
    }

    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * dst_stride_y);

    REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
    CALCULATE_Z_OFFSET(M0, uint, zout, get_global_id(1), HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += get_global_id(2) * dst_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_OUTPUT_AS_3D)

    // Add offset for batched GEMM
    dst_addr += get_global_id(2) * dst_stride_z;

#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK_BROADCAST(M0, c, bias0);
#endif // defined(MIXED_PRECISION)

#else // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
                                    2) * bias_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
#if defined(MIXED_PRECISION)
    CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
    ADD_BLOCK(M0, c, bias_hp);
#else  // defined(MIXED_PRECISION)
    ADD_BLOCK(M0, c, bias);
#endif // defined(MIXED_PRECISION)

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
#if defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
#else  // defined(MIXED_PRECISION)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
#endif // defined(MIXED_PRECISION)
#endif // defined(ACTIVATION_TYPE)

    // Store output block
#if defined(MIXED_PRECISION)
    CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#else  // defined(MIXED_PRECISION)
    STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
#endif // defined(MIXED_PRECISION)

#undef LHS_BLOCK_SIZE
#undef LHS_OFFSET_X
#undef LHS_STEP_X
#undef RHS_BLOCK_SIZE
#undef RHS_OFFSET_X
#undef RHS_STEP_X
#undef PIXEL_UNIT
#undef LHS_STEP_LOOP
#undef RHS_STEP_LOOP
}
#endif // defined(OPENCL_IMAGE_SUPPORT)
2383
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002384#if defined(LHS_TRANSPOSE)
2385
// Shorthand for a vector type of SIZE elements of TYPE
#define VTYPE(TYPE, SIZE) VEC_DATA_TYPE(TYPE, SIZE)

// ARM_VFMA(N0, a, b, c): c += a * b.
// With MIXED_PRECISION, a and b are first converted to N0-wide vectors of DATA_TYPE_ACCUMULATOR.
// When GPU_ARCH == GPU_ARCH_MIDGARD a separate multiply and add is used instead of fma().
// The expansion already ends with ';'.
#if defined(MIXED_PRECISION)

#if(GPU_ARCH == GPU_ARCH_MIDGARD)
#define ARM_VFMA(N0, a, b, c) c += (CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))) * (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0)));
#else // GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(N0, a, b, c) c = fma((CONVERT(a, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (CONVERT(b, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0))), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

#else // defined(MIXED_PRECISION)

#if(GPU_ARCH == GPU_ARCH_MIDGARD)
#define ARM_VFMA(N0, a, b, c) c += (a) * (b);
#else // GPU_ARCH == GPU_ARCH_MIDGARD
#define ARM_VFMA(N0, a, b, c) c = fma((a), (b), (c));
#endif // GPU_ARCH == GPU_ARCH_MIDGARD

#endif // defined(MIXED_PRECISION)

// Column-vector (a, M0 elements) by row-vector (b, N0 elements) products.
// Element i of a is broadcast to N0 lanes and accumulated into row C##i.
#define ARM_VVM_T_NT_1xN0x1(N0, TYPE, a, b, C)        \
    ({                                                \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a), b, (C##0)); \
    })
#define ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s0), b, (C##0)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s1), b, (C##1)); \
    })
#define ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_2xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s2), b, (C##2)); \
    })
#define ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_3xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s3), b, (C##3)); \
    })
#define ARM_VVM_T_NT_8xN0x1(N0, TYPE, a, b, C)            \
    ({                                                    \
        ARM_VVM_T_NT_4xN0x1(N0, TYPE, a, b, C);           \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s4), b, (C##4)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s5), b, (C##5)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s6), b, (C##6)); \
        ARM_VFMA(N0, (VTYPE(TYPE, N0))(a.s7), b, (C##7)); \
    })

// Factory macro for the column-vector (transposed) by row-vector (not transposed) multiplication. K0 = 1
// a is the column-vector (transposed)
// b is the row-vector (not transposed)
// C is the output matrix
// Lower case is a vector (a, b)
// Upper case is a matrix (C)
#define ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, a, b, C) ARM_VVM_T_NT_##M0##xN0x1(N0, TYPE, a, b, C)

// Matrix-by-matrix products for K0 = 1, 2, 3, 4, 8, 16.
// A and B are the base names of the K0 LHS/RHS vectors (A##0, A##1, ...); C is the base name of the output rows.
#define ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##0), (B##0), C); \
    })
#define ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x1(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##1), (B##1), C); \
    })
#define ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x2(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##2), (B##2), C); \
    })
#define ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x3(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##3), (B##3), C); \
    })
#define ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, A, B, C)             \
    ({                                                         \
        ARM_MM_T_NT_M0xN0x4(M0, N0, TYPE, A, B, C);            \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##4), (B##4), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##5), (B##5), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##6), (B##6), C); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (A##7), (B##7), C); \
    })
// K0 = 16. Rows 8..F must go through the vector-by-vector macro: ARM_MM_T_NT_M0xN0x1 would try to paste
// a suffix onto the already parenthesised name, which is not a valid token.
// The parameters are called LHS/RHS/DST (not A/B/C) because the hexadecimal row suffixes A, B and C would
// otherwise be replaced by the macro arguments themselves (e.g. A##A would become "aa" instead of "aA").
#define ARM_MM_T_NT_M0xN0x16(M0, N0, TYPE, LHS, RHS, DST)             \
    ({                                                                \
        ARM_MM_T_NT_M0xN0x8(M0, N0, TYPE, LHS, RHS, DST);             \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##8), (RHS##8), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##9), (RHS##9), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##A), (RHS##A), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##B), (RHS##B), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##C), (RHS##C), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##D), (RHS##D), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##E), (RHS##E), DST); \
        ARM_VVM_T_NT_M0xN0x1(M0, N0, TYPE, (LHS##F), (RHS##F), DST); \
    })

// Factory macro for the matrix (transposed) by matrix (not transposed) multiplication.
// The dimensions for this matrix multiplications are defined through M0, N0 and K0
// The dimensions supported are:
// M0: 1, 2, 3, 4, 8
// N0: 1, 2, 3, 4, 8, 16
// K0: 1, 2, 3, 4, 8, 16
// This macro calls the vector-by-vector macro K0 times
// A, B and C are matrices
#define ARM_MM_T_NT(M0, N0, K0, TYPE, A, B, C) \
    CONCAT(ARM_MM_T_NT_M0xN0x, K0)             \
    (M0, N0, TYPE, A, B, C)
2493
2494/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
2495 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2496 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2497 *
2498 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2499 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002500 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002501 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2502 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2503 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2506 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2507 * - M0 = 2, 3, 4, 8
2508 * - N0 = 2, 3, 4, 8, 16
2509 * - K0 = 2, 3, 4, 8, 16
2510 * - V0 >= 1
2511 * - H0 >= 1
2512 *
2513 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2514 * The activation function is performed after the bias addition
2515 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2516 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2517 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2518 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2519 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2520 *
2521 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
2522 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2526 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2527 * @param[in] rhs_ptr Pointer to the RHS reshaped matrix. Supported data type: same as @p lhs_ptr
2528 * @param[in] rhs_stride_x Stride of the RHS reshaped matrix in X dimension (in bytes)
 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
2532 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
2533 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2534 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2535 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2536 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2537 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2538 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2539 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2540 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2541 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2542 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2543 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2544 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01002545 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002546 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2547 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2548 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2549 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2550 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2551 */
2552__kernel void gemm_mm_reshaped_lhs_t_rhs_nt(IMAGE_DECLARATION(lhs),
2553 IMAGE_DECLARATION(rhs),
2554#if defined(BETA)
2555 IMAGE_DECLARATION(bias),
2556#endif // defined(BETA)
2557 IMAGE_DECLARATION(dst),
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01002558 uint k,
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002559 uint lhs_stride_z,
2560 uint rhs_stride_z,
2561#if defined(BETA)
2562 uint bias_stride_z,
2563#endif //defined(BETA)
2564 uint dst_stride_z
2565#if defined(REINTERPRET_OUTPUT_AS_3D)
2566 ,
2567 uint dst_cross_plane_pad
2568#endif // REINTERPRET_OUTPUT_AS_3D
2569 )
2570{
2571 // Block size
2572#define LHS_BLOCK_SIZE ((K0) * (M0))
2573
2574#if defined(LHS_INTERLEAVE)
2575#define LHS_OFFSET_X (M0)
2576#define LHS_STEP_X ((M0) * (V0))
2577#define LHS_STEP_LOOP (1)
2578#else // defined(INTERLEAVE)
2579#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2580#define LHS_STEP_X (M0)
2581#define LHS_STEP_LOOP (V0)
2582#endif // defined(INTERLEAVE)
2583
2584 // Block size
2585#define RHS_BLOCK_SIZE ((K0) * (N0))
2586
2587 // RHS offset and step X
2588#if defined(RHS_INTERLEAVE)
2589#define RHS_OFFSET_X (N0)
2590#define RHS_STEP_X ((N0) * (H0))
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002591#else // defined(RHS_INTERLEAVE)
2592#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2593#define RHS_STEP_X (N0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002594#endif // defined(RHS_INTERLEAVE)
2595
2596 const uint x = get_global_id(0);
2597 const uint y = get_global_id(1);
2598 const uint z = get_global_id(2);
2599
2600#if defined(DUMMY_WORK_ITEMS)
2601 if((x * N0 >= N) || (y * M0 >= M))
2602 {
2603 return;
2604 }
2605#endif // defined(DUMMY_WORK_ITEMS)
2606
2607 // Compute LHS matrix address
2608 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2609
2610 // Compute RHS matrix address
2611 __global uchar *rhs_addr = rhs_ptr + rhs_offset_first_element_in_bytes + (x % H0) * (uint)RHS_OFFSET_X * sizeof(DATA_TYPE) + (x / (uint)H0) * rhs_stride_y;
2612
2613#if defined(MATRIX_B_DEPTH)
2614 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2615 rhs_addr += (z % MATRIX_B_DEPTH) * rhs_stride_z;
2616#else // defined(MATRIX_B_DEPTH)
2617 rhs_addr += z * rhs_stride_z;
2618#endif // defined(MATRIX_B_DEPTH)
2619
2620 // Initialize the accumulators
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002621 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002622
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002623 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2624
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002625 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
2626 __global DATA_TYPE *rhs = (__global DATA_TYPE *)(rhs_addr);
2627
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01002628 for(int i = 0; i < k; i += K0)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002629 {
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002630 VEC_DATA_TYPE(DATA_TYPE, M0)
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002631 a0;
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002632 VEC_DATA_TYPE(DATA_TYPE, N0)
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002633 b0;
2634
2635 a0 = VLOAD(M0)(0, lhs);
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002636 b0 = VLOAD(N0)(0, rhs);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002637
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002638 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002639
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002640 lhs += LHS_STEP_X;
2641 rhs += RHS_STEP_X;
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002642
Gian Marco Iodice05639f62019-09-24 12:05:06 +01002643#if K0 > 1
2644 a0 = VLOAD(M0)(0, lhs);
2645 b0 = VLOAD(N0)(0, rhs);
2646
2647 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2648
2649 lhs += LHS_STEP_X;
2650 rhs += RHS_STEP_X;
2651#endif // K0 > 1
2652
2653#if K0 > 2
2654 a0 = VLOAD(M0)(0, lhs);
2655 b0 = VLOAD(N0)(0, rhs);
2656
2657 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2658
2659 lhs += LHS_STEP_X;
2660 rhs += RHS_STEP_X;
2661#endif // K0 > 2
2662
2663#if K0 > 3
2664 a0 = VLOAD(M0)(0, lhs);
2665 b0 = VLOAD(N0)(0, rhs);
2666
2667 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2668
2669 lhs += LHS_STEP_X;
2670 rhs += RHS_STEP_X;
2671#endif // K0 > 3
2672
2673#if K0 > 4
2674 a0 = VLOAD(M0)(0, lhs);
2675 b0 = VLOAD(N0)(0, rhs);
2676
2677 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2678
2679 lhs += LHS_STEP_X;
2680 rhs += RHS_STEP_X;
2681
2682 a0 = VLOAD(M0)(0, lhs);
2683 b0 = VLOAD(N0)(0, rhs);
2684
2685 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2686
2687 lhs += LHS_STEP_X;
2688 rhs += RHS_STEP_X;
2689
2690 a0 = VLOAD(M0)(0, lhs);
2691 b0 = VLOAD(N0)(0, rhs);
2692
2693 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2694
2695 lhs += LHS_STEP_X;
2696 rhs += RHS_STEP_X;
2697
2698 a0 = VLOAD(M0)(0, lhs);
2699 b0 = VLOAD(N0)(0, rhs);
2700
2701 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2702
2703 lhs += LHS_STEP_X;
2704 rhs += RHS_STEP_X;
2705#endif // K0 > 4
2706
2707#if K0 > 8
2708 a0 = VLOAD(M0)(0, lhs);
2709 b0 = VLOAD(N0)(0, rhs);
2710
2711 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2712
2713 lhs += LHS_STEP_X;
2714 rhs += RHS_STEP_X;
2715
2716 a0 = VLOAD(M0)(0, lhs);
2717 b0 = VLOAD(N0)(0, rhs);
2718
2719 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2720
2721 lhs += LHS_STEP_X;
2722 rhs += RHS_STEP_X;
2723
2724 a0 = VLOAD(M0)(0, lhs);
2725 b0 = VLOAD(N0)(0, rhs);
2726
2727 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2728
2729 lhs += LHS_STEP_X;
2730 rhs += RHS_STEP_X;
2731
2732 a0 = VLOAD(M0)(0, lhs);
2733 b0 = VLOAD(N0)(0, rhs);
2734
2735 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2736
2737 lhs += LHS_STEP_X;
2738 rhs += RHS_STEP_X;
2739
2740 a0 = VLOAD(M0)(0, lhs);
2741 b0 = VLOAD(N0)(0, rhs);
2742
2743 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2744
2745 lhs += LHS_STEP_X;
2746 rhs += RHS_STEP_X;
2747
2748 a0 = VLOAD(M0)(0, lhs);
2749 b0 = VLOAD(N0)(0, rhs);
2750
2751 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2752
2753 lhs += LHS_STEP_X;
2754 rhs += RHS_STEP_X;
2755
2756 a0 = VLOAD(M0)(0, lhs);
2757 b0 = VLOAD(N0)(0, rhs);
2758
2759 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2760
2761 lhs += LHS_STEP_X;
2762 rhs += RHS_STEP_X;
2763
2764 a0 = VLOAD(M0)(0, lhs);
2765 b0 = VLOAD(N0)(0, rhs);
2766
2767 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
2768
2769 lhs += LHS_STEP_X;
2770 rhs += RHS_STEP_X;
2771#endif // K0 > 8
2772
2773#ifndef LHS_INTERLEAVE
2774 lhs += (M0 * K0 * (V0 - 1));
2775#endif // LHS_INTERLEAVE
2776
2777#ifndef RHS_INTERLEAVE
2778 rhs += (N0 * K0 * (H0 - 1));
2779#endif // RHS_INTERLEAVE
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002780 }
2781
2782 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
2783
2784 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
2785
2786#if defined(REINTERPRET_OUTPUT_AS_3D)
2787
2788 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
2789 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
2790 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2791 // multiply dst_stride_z by DEPTH_GEMM3D
2792 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2793
2794#else // defined(REINTERPRET_OUTPUT_AS_3D)
2795
2796 // Add offset for batched GEMM
2797 dst_addr += z * dst_stride_z;
2798
2799#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2800
2801 // Multiply by the weight of matrix-matrix product and store the result
2802#if defined(ALPHA)
2803 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
2804#endif // defined(ALPHA)
2805
2806 // Add beta*bias
2807#if defined(BETA)
2808#if defined(BROADCAST_BIAS)
2809 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
2810
2811 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2812
2813#ifndef UNIT_BETA
2814 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
2815#endif // UNIT_BIAS
2816
2817 // c = c + bias[broadcasted]
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002818#if defined(MIXED_PRECISION)
2819 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2820 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
2821#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002822 ADD_BLOCK_BROADCAST(M0, c, bias0);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002823#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002824
2825#else // defined(BROADCAST_BIAS)
2826 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
2827
2828 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
2829
2830#ifndef UNIT_BETA
2831 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
2832#endif // UNIT_BIAS
2833
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002834#if defined(MIXED_PRECISION)
2835 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
2836 ADD_BLOCK(M0, c, bias_hp);
2837#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002838 ADD_BLOCK(M0, c, bias);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002839#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002840
2841#endif // defined(BROADCAST_BIAS)
2842#endif // defined(BETA)
2843
2844#if defined(ACTIVATION_TYPE)
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002845#if defined(MIXED_PRECISION)
2846 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
2847#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002848 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
Georgios Pinitasa07ce152019-10-11 17:38:50 +01002849#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002850#endif // defined(ACTIVATION_TYPE)
2851
2852 // Store output block
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002853#if defined(MIXED_PRECISION)
2854 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
2855#else // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002856 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
Gian Marco Iodice0c17aa22019-09-27 09:23:15 +01002857#endif // defined(MIXED_PRECISION)
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01002858
2859#undef LHS_BLOCK_SIZE
2860#undef LHS_OFFSET_X
2861#undef LHS_STEP_X
2862#undef RHS_BLOCK_SIZE
2863#undef RHS_OFFSET_X
2864#undef RHS_STEP_X
2865}
2866
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002867#if defined(OPENCL_IMAGE_SUPPORT)
2868/** This OpenCL kernel computes the matrix multiplication between 2 matrices. The RHS matrix is stored in OpenCL image object.
2869 * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be transposed
2870 * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be NOT transposed
2871 *
2872 * @note -DOPENCL_IMAGE_SUPPORT must be passed at compile time in order to compile this OpenCL kernel
2873 * @note LHS_TRANSPOSE should be passed at compile time in order to compile this OpenCL kernel (e.g. -DLHS_TRANSPOSE).
2874 * @note The height of the RHS matrix should be passed at compile time using -DRHS_HEIGHT=<value> (e.g. -DRHS_HEIGHT=32)
2875 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
2876 * @note The GEMM's dimensions M, N and K must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=90 and -DK=24).
2877 * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (e.g. -DM0=4, -DN0=8, -DK0=4).
2878 * @note The number of M0xK0 vertical blocks stored on the same output row of the reshaped LHS matrix must be passed at compile time using -DV0 (e.g. -DV0=2)
2879 * @note The number of K0xN0 horizontal blocks stored on the same output row of the reshaped RHS matrix must be passed at compile time using -DH0 (e.g. -DH0=2)
 * @note If the M0xK0 blocks in the reshaped LHS matrix have been interleaved, the option -DLHS_INTERLEAVE must be passed at compile time.
 * @note If the K0xN0 blocks in the reshaped RHS matrix have been interleaved, the option -DRHS_INTERLEAVE must be passed at compile time.
2882 * @note Only the following configurations of M0, N0 and K0 are currently supported:
2883 * - M0 = 2, 3, 4, 8
2884 * - N0 = 4, 8, 16
2885 * - K0 = 4, 8, 16
2886 * - V0 >= 1
2887 * - H0 >= 1
2888 *
2889 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2890 * The activation function is performed after the bias addition
2891 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
2892 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2893 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2894 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2895 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix NOT reshaped
2896 *
2897 * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F32
2898 * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
2902 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS reshaped matrix
2903 * @param[in] rhs_img The RHS reshaped matrix as cl_image 2d. Supported data type: same as @p lhs_ptr
2904 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2905 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2906 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
2907 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2908 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
2909 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2910 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
2911 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2912 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2913 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2914 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2915 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01002916 * @param[in] k Number of columns in LHS matrix and rows in RHS matrix not reshaped.
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002917 * @param[in] lhs_stride_z Stride of the LHS reshaped matrix in Z dimension (in bytes)
2918 * @param[in] rhs_stride_z Stride of the RHS reshaped matrix in Z dimension (in bytes)
2919 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2920 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2921 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2922 */
2923__kernel void gemm_mm_reshaped_lhs_t_rhs_nt_texture(IMAGE_DECLARATION(lhs),
2924 __read_only image2d_t rhs_img,
2925#if defined(BETA)
2926 IMAGE_DECLARATION(bias),
2927#endif // defined(BETA)
2928 IMAGE_DECLARATION(dst),
Gian Marco Iodicee5563d92020-06-25 17:18:36 +01002929 uint k,
Gian Marco Iodicee3a849a2020-06-10 17:59:30 +01002930 uint lhs_stride_z,
2931 uint rhs_stride_z,
2932#if defined(BETA)
2933 uint bias_stride_z,
2934#endif //defined(BETA)
2935 uint dst_stride_z
2936#if defined(REINTERPRET_OUTPUT_AS_3D)
2937 ,
2938 uint dst_cross_plane_pad
2939#endif // REINTERPRET_OUTPUT_AS_3D
2940 )
2941{
2942 // Pixel unit
2943#define PIXEL_UNIT CONVERT_VECTOR_SIZE_TO_PIXEL_UNIT(N0)
2944
2945 // Block size
2946#define LHS_BLOCK_SIZE ((K0) * (M0))
2947
2948#if defined(LHS_INTERLEAVE)
2949#define LHS_OFFSET_X (M0)
2950#define LHS_STEP_X ((M0) * (V0))
2951#define LHS_STEP_LOOP (1)
2952#else // defined(INTERLEAVE)
2953#define LHS_OFFSET_X (LHS_BLOCK_SIZE)
2954#define LHS_STEP_X (M0)
2955#define LHS_STEP_LOOP (V0)
2956#endif // defined(INTERLEAVE)
2957
2958 // Block size
2959#define RHS_BLOCK_SIZE ((K0) * (PIXEL_UNIT))
2960
2961 // RHS offset and step X
2962#if defined(RHS_INTERLEAVE)
2963#define RHS_OFFSET_X (PIXEL_UNIT)
2964#define RHS_STEP_X ((PIXEL_UNIT) * (H0))
2965#else // defined(RHS_INTERLEAVE)
2966#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
2967#define RHS_STEP_X (PIXEL_UNIT)
2968#endif // defined(RHS_INTERLEAVE)
2969
2970 const uint x = get_global_id(0);
2971 const uint y = get_global_id(1);
2972 const uint z = get_global_id(2);
2973
2974#if defined(DUMMY_WORK_ITEMS)
2975 if((x * N0 >= N) || (y * M0 >= M))
2976 {
2977 return;
2978 }
2979#endif // defined(DUMMY_WORK_ITEMS)
2980
2981 // Compute LHS matrix address
2982 __global uchar *lhs_addr = lhs_ptr + lhs_offset_first_element_in_bytes + (y % V0) * (uint)LHS_OFFSET_X * sizeof(DATA_TYPE) + (y / V0) * (uint)lhs_stride_y + (z * lhs_stride_z);
2983
2984#if defined(MATRIX_B_DEPTH)
2985 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2986 const uint z_rhs = (z % MATRIX_B_DEPTH);
2987#else // defined(MATRIX_B_DEPTH)
2988 const uint z_rhs = z;
2989#endif // defined(MATRIX_B_DEPTH)
2990
2991 // Compute RHS matrix coordinates
2992 uint x_rhs = (x % H0) * (uint)RHS_OFFSET_X;
2993 const uint y_rhs = (x / (uint)H0) + z_rhs * RHS_HEIGHT;
2994
2995 // Initialize the accumulators
2996 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE_ACCUMULATOR, N0), c, 0);
2997
2998 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2999
3000 __global DATA_TYPE *lhs = (__global DATA_TYPE *)(lhs_addr);
3001
3002 for(int i = 0; i < K; i += K0)
3003 {
3004 VEC_DATA_TYPE(DATA_TYPE, M0)
3005 a0;
3006 VEC_DATA_TYPE(DATA_TYPE, N0)
3007 b0;
3008
3009 a0 = VLOAD(M0)(0, lhs);
3010 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 0 * RHS_STEP_X), (y_rhs));
3011
3012 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3013
3014 lhs += LHS_STEP_X;
3015
3016#if K0 > 1
3017 a0 = VLOAD(M0)(0, lhs);
3018 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 1 * RHS_STEP_X), (y_rhs));
3019
3020 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3021
3022 lhs += LHS_STEP_X;
3023#endif // K0 > 1
3024
3025#if K0 > 2
3026 a0 = VLOAD(M0)(0, lhs);
3027 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 2 * RHS_STEP_X), (y_rhs));
3028
3029 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3030
3031 lhs += LHS_STEP_X;
3032#endif // K0 > 2
3033
3034#if K0 > 3
3035 a0 = VLOAD(M0)(0, lhs);
3036 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 3 * RHS_STEP_X), (y_rhs));
3037
3038 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3039
3040 lhs += LHS_STEP_X;
3041#endif // K0 > 3
3042
3043#if K0 > 4
3044 a0 = VLOAD(M0)(0, lhs);
3045 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 4 * RHS_STEP_X), (y_rhs));
3046
3047 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3048
3049 lhs += LHS_STEP_X;
3050
3051 a0 = VLOAD(M0)(0, lhs);
3052 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 5 * RHS_STEP_X), (y_rhs));
3053
3054 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3055
3056 lhs += LHS_STEP_X;
3057
3058 a0 = VLOAD(M0)(0, lhs);
3059 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 6 * RHS_STEP_X), (y_rhs));
3060
3061 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3062
3063 lhs += LHS_STEP_X;
3064
3065 a0 = VLOAD(M0)(0, lhs);
3066 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 7 * RHS_STEP_X), (y_rhs));
3067
3068 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3069
3070 lhs += LHS_STEP_X;
3071#endif // K0 > 4
3072
3073#if K0 > 8
3074 a0 = VLOAD(M0)(0, lhs);
3075 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 8 * RHS_STEP_X), (y_rhs));
3076
3077 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3078
3079 lhs += LHS_STEP_X;
3080
3081 a0 = VLOAD(M0)(0, lhs);
3082 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 9 * RHS_STEP_X), (y_rhs));
3083
3084 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3085
3086 lhs += LHS_STEP_X;
3087
3088 a0 = VLOAD(M0)(0, lhs);
3089 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 10 * RHS_STEP_X), (y_rhs));
3090
3091 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3092
3093 lhs += LHS_STEP_X;
3094
3095 a0 = VLOAD(M0)(0, lhs);
3096 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 11 * RHS_STEP_X), (y_rhs));
3097
3098 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3099
3100 lhs += LHS_STEP_X;
3101
3102 a0 = VLOAD(M0)(0, lhs);
3103 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 12 * RHS_STEP_X), (y_rhs));
3104
3105 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3106
3107 lhs += LHS_STEP_X;
3108
3109 a0 = VLOAD(M0)(0, lhs);
3110 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 13 * RHS_STEP_X), (y_rhs));
3111
3112 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3113
3114 lhs += LHS_STEP_X;
3115
3116 a0 = VLOAD(M0)(0, lhs);
3117 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 14 * RHS_STEP_X), (y_rhs));
3118
3119 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3120
3121 lhs += LHS_STEP_X;
3122
3123 a0 = VLOAD(M0)(0, lhs);
3124 b0 = READ_IMAGE2D(DATA_TYPE, PIXEL_UNIT, rhs_img, (x_rhs + 15 * RHS_STEP_X), (y_rhs));
3125
3126 ARM_MM_T_NT(M0, N0, 1, DATA_TYPE, a, b, c);
3127
3128 lhs += LHS_STEP_X;
3129#endif // K0 > 8
3130
3131#ifndef LHS_INTERLEAVE
3132 lhs += (M0 * K0 * (V0 - 1));
3133#endif // LHS_INTERLEAVE
3134
3135 x_rhs += K0 * RHS_STEP_X;
3136#ifndef RHS_INTERLEAVE
3137 x_rhs += (PIXEL_UNIT * K0 * (H0 - 1));
3138#endif // RHS_INTERLEAVE
3139 }
3140
3141 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3142
3143 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
3144
3145#if defined(REINTERPRET_OUTPUT_AS_3D)
3146
3147 // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
3148 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3149 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3150 // multiply dst_stride_z by DEPTH_GEMM3D
3151 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3152
3153#else // defined(REINTERPRET_OUTPUT_AS_3D)
3154
3155 // Add offset for batched GEMM
3156 dst_addr += z * dst_stride_z;
3157
3158#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3159
3160 // Multiply by the weight of matrix-matrix product and store the result
3161#if defined(ALPHA)
3162 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
3163#endif // defined(ALPHA)
3164
3165 // Add beta*bias
3166#if defined(BETA)
3167#if defined(BROADCAST_BIAS)
3168 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE));
3169
3170 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
3171
3172#ifndef UNIT_BETA
3173 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
3174#endif // UNIT_BIAS
3175
3176 // c = c + bias[broadcasted]
3177#if defined(MIXED_PRECISION)
3178 CONVERT_BLOCK(1, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3179 ADD_BLOCK_BROADCAST(M0, c, bias_hp0);
3180#else // defined(MIXED_PRECISION)
3181 ADD_BLOCK_BROADCAST(M0, c, bias0);
3182#endif // defined(MIXED_PRECISION)
3183
3184#else // defined(BROADCAST_BIAS)
3185 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * bias_stride_y) + z * bias_stride_z;
3186
3187 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
3188
3189#ifndef UNIT_BETA
3190 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
3191#endif // UNIT_BIAS
3192
3193#if defined(MIXED_PRECISION)
3194 CONVERT_BLOCK(M0, N0, DATA_TYPE_ACCUMULATOR, bias, bias_hp);
3195 ADD_BLOCK(M0, c, bias_hp);
3196#else // defined(MIXED_PRECISION)
3197 ADD_BLOCK(M0, c, bias);
3198#endif // defined(MIXED_PRECISION)
3199
3200#endif // defined(BROADCAST_BIAS)
3201#endif // defined(BETA)
3202
3203#if defined(ACTIVATION_TYPE)
3204#if defined(MIXED_PRECISION)
3205 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE_ACCUMULATOR, c, A_VAL, B_VAL);
3206#else // defined(MIXED_PRECISION)
3207 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
3208#endif // defined(MIXED_PRECISION)
3209#endif // defined(ACTIVATION_TYPE)
3210
3211 // Store output block
3212#if defined(MIXED_PRECISION)
3213 CONVERT_STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
3214#else // defined(MIXED_PRECISION)
3215 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
3216#endif // defined(MIXED_PRECISION)
3217
3218#undef LHS_BLOCK_SIZE
3219#undef LHS_OFFSET_X
3220#undef LHS_STEP_X
3221#undef RHS_BLOCK_SIZE
3222#undef RHS_OFFSET_X
3223#undef RHS_STEP_X
3224#undef PIXEL_UNIT
3225#undef LHS_STEP_LOOP
3226#undef RHS_STEP_LOOP
3227}
3228#endif // defined(OPENCL_IMAGE_SUPPORT)
3229
Giorgio Arenaae99b6e2019-08-01 14:22:12 +01003230#endif // defined(LHS_TRANSPOSE)
3231
Gian Marco Iodicebf9731e2018-12-12 10:18:04 +00003232#endif // defined(M0) && defined(N0) && defined(K0) && defined(V0) && defined(H0) && defined(K) && defined(DATA_TYPE)
3233
giuros01b3204e72019-04-01 13:50:22 +01003234#if defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
3235
// Fused multiply-add accumulation: c = fma(a, b, c).
// c is both read and written, so it must be an assignable expression.
// Every argument is parenthesized so that compound expressions expand safely.
#define VFMA(a, b, c)             \
    ({                            \
        (c) = fma((a), (b), (c)); \
    })
3240
3241#if M0 == 1
3242#define RHS_VFMA_M0xN0(i, a, b, c) \
3243 ({ \
3244 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3245 })
3246#elif M0 == 2 // M0 == 2
3247#define RHS_VFMA_M0xN0(i, a, b, c) \
3248 ({ \
3249 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3250 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3251 })
3252#elif M0 == 3 // M0 == 3
3253#define RHS_VFMA_M0xN0(i, a, b, c) \
3254 ({ \
3255 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3256 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3257 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3258 })
3259#elif M0 == 4 // M0 == 4
3260#define RHS_VFMA_M0xN0(i, a, b, c) \
3261 ({ \
3262 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3263 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3264 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3265 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3266 })
3267#elif M0 == 5 // M0 == 5
3268#define RHS_VFMA_M0xN0(i, a, b, c) \
3269 ({ \
3270 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3271 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3272 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3273 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3274 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3275 })
3276#elif M0 == 6 // M0 == 6
3277#define RHS_VFMA_M0xN0(i, a, b, c) \
3278 ({ \
3279 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3280 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3281 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3282 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3283 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3284 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3285 })
3286#elif M0 == 7 // M0 == 7
3287#define RHS_VFMA_M0xN0(i, a, b, c) \
3288 ({ \
3289 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3290 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3291 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3292 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3293 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3294 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3295 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3296 })
3297#elif M0 == 8 // M0 == 8
3298#define RHS_VFMA_M0xN0(i, a, b, c) \
3299 ({ \
3300 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##0).s##i), b, (c##0)); \
3301 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##1).s##i), b, (c##1)); \
3302 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##2).s##i), b, (c##2)); \
3303 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##3).s##i), b, (c##3)); \
3304 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##4).s##i), b, (c##4)); \
3305 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##5).s##i), b, (c##5)); \
3306 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##6).s##i), b, (c##6)); \
3307 VFMA((VEC_DATA_TYPE(DATA_TYPE, N0))((a##7).s##i), b, (c##7)); \
3308 })
3309#else // M0 not supported
3310#error "M0 not supported"
3311#endif // M0 not supported
3312
3313/** This OpenCL kernel computes the matrix multiplication between 2 matrices.
3314 * The LHS matrix is NOT reshaped
3315 * The RHS matrix is NOT reshaped
3316 *
3317 * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
 * @note The GEMM's dimensions (M,N and K) must be passed at compile time using -DM, -DN and -DK (e.g. -DM=52, -DN=30 and -DK=90)
3319 * @note The number of columns of LHS matrix must be passed at compile time using -DK (e.g. -DK=64)
3320 * @note The number of M0 rows to process must be passed at compile time using -DM0 (e.g. -DM0=2)
3321 * @note The number of K0 partial accumulations must be passed at compile time using -DK0 (e.g., -DK0=2)
3322 * @note The number of N0 columns to process must be passed at compile time using -DN0 (e.g. -DN0=2)
giuros01b3204e72019-04-01 13:50:22 +01003323 * @note Only the following configurations of M0, N0 and K0 are currently supported:
3324 * - M0 = 1, 2, 3, 4, 5, 6, 7, 8
3325 * - N0 = 2, 3, 4, 8, 16
3326 * - K0 = 2, 3, 4, 8, 16
3327 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003328 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01003329 * The activation function is performed after the bias addition
giuros01b3204e72019-04-01 13:50:22 +01003330 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
3331 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
3332 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3333 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3334 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3335 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
3336 *
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003337 * @param[in] lhs_ptr Pointer to the LHS matrix. Supported data type: F16/F32
3338 * @param[in] lhs_stride_x Stride of the LHS matrix in X dimension (in bytes)
3339 * @param[in] lhs_step_x lhs_stride_x * number of elements along X processed per workitem(in bytes)
3340 * @param[in] lhs_stride_y Stride of the LHS matrix in Y dimension (in bytes)
3341 * @param[in] lhs_step_y lhs_stride_y * number of elements along Y processed per workitem(in bytes)
3342 * @param[in] lhs_offset_first_element_in_bytes The offset of the first element in the LHS matrix
3343 * @param[in] rhs_ptr Pointer to the RHS matrix. Supported data type: same as @p lhs_ptr
3344 * @param[in] rhs_stride_x Stride of the RHS matrix in X dimension (in bytes)
3345 * @param[in] rhs_step_x rhs_stride_x * number of elements along X processed per workitem(in bytes)
3346 * @param[in] rhs_stride_y Stride of the RHS matrix in Y dimension (in bytes)
3347 * @param[in] rhs_step_y rhs_stride_y * number of elements along Y processed per workitem(in bytes)
3348 * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003349 * @param[in] bias_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3350 * @param[in] bias_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3351 * @param[in] bias_step_x (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
3352 * @param[in] bias_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3353 * @param[in] bias_step_y (Optional) bias_stride_y * number of elements along Y processed per workitem(in bytes)
3354 * @param[in] bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
3355 * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
3356 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
3357 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3358 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
3359 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3360 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
3361 * @param[in] lhs_stride_z Stride of the LHS matrix in Z dimension (in bytes)
3362 * @param[in] rhs_stride_z Stride of the RHS matrix in Z dimension (in bytes)
3363 * @param[in] bias_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
3364 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
3365 * @param[in] lhs_cross_plane_pad (Optional) Bottom paddings for LHS matrix in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
3366 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings for the output matrix in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
giuros01b3204e72019-04-01 13:50:22 +01003367 */
3368__kernel void gemm_mm_native(IMAGE_DECLARATION(lhs),
3369 IMAGE_DECLARATION(rhs),
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003370#if defined(BETA)
3371 IMAGE_DECLARATION(bias),
3372#endif // defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01003373 IMAGE_DECLARATION(dst),
3374 uint lhs_stride_z,
3375 uint rhs_stride_z,
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003376#if defined(BETA)
3377 uint bias_stride_z,
3378#endif //defined(BETA)
giuros01b3204e72019-04-01 13:50:22 +01003379 uint dst_stride_z
3380#if defined(REINTERPRET_INPUT_AS_3D)
3381 ,
3382 uint lhs_cross_plane_pad
3383#endif // REINTERPRET_INPUT_AS_3D
3384#if defined(REINTERPRET_OUTPUT_AS_3D)
3385 ,
3386 uint dst_cross_plane_pad
3387#endif // REINTERPRET_OUTPUT_AS_3D
3388 )
3389{
3390 // Block size
3391#define RHS_BLOCK_SIZE ((K0) * (N0))
3392
3393 // RHS offset and step X
3394#define RHS_OFFSET_X (RHS_BLOCK_SIZE)
3395
3396 uint x = get_global_id(0);
3397 uint y = get_global_id(1);
3398 uint z = get_global_id(2);
3399
3400#if defined(DUMMY_WORK_ITEMS)
3401 if((x * N0 >= N) || (y * M0 >= M))
3402 {
3403 return;
3404 }
3405#endif // defined(DUMMY_WORK_ITEMS)
3406
3407 // Compute LHS matrix address
3408 uint lhs_offset = lhs_offset_first_element_in_bytes + y * M0 * (uint)lhs_stride_y;
3409
3410 // Compute RHS matrix address
3411 uint rhs_offset = rhs_offset_first_element_in_bytes + x * N0 * sizeof(DATA_TYPE);
3412
3413#if defined(MATRIX_B_DEPTH)
3414 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3415 rhs_offset += (z % MATRIX_B_DEPTH) * rhs_stride_z;
3416#else // defined(MATRIX_B_DEPTH)
3417 rhs_offset += z * rhs_stride_z;
3418#endif // defined(MATRIX_B_DEPTH)
3419
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003420 REPEAT_VAR_INIT_TO_CONST(M0, uint, zlhs, 0);
3421 REPEAT_VAR_INIT_TO_CONST(16, uint, zero, 0);
giuros01b3204e72019-04-01 13:50:22 +01003422
3423#if defined(REINTERPRET_INPUT_AS_3D)
3424 // The plane (zlhs) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
3425 CALCULATE_Z_OFFSET(M0, uint, zlhs, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, lhs_cross_plane_pad, lhs_stride_y);
3426
3427 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3428 // multiply lhs_stride_z by DEPTH_GEMM3D
3429 lhs_offset += z * lhs_stride_z * DEPTH_GEMM3D;
3430
3431#else // defined(REINTERPRET_INPUT_AS_3D)
3432
3433 // Add offset for batched GEMM
3434 lhs_offset += z * lhs_stride_z;
3435
3436#endif // defined(REINTERPRET_INPUT_AS_3D)
3437
3438 // Initialize the accumulators
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003439 REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
giuros01b3204e72019-04-01 13:50:22 +01003440
3441 int i = 0;
3442 for(; i <= (K - K0); i += K0)
3443 {
3444 // Supported cases (M0, K0):
3445 // 1,2 - 1,3 - 1,4 - 1,8 - 1,16
3446 // 2,2 - 2,3 - 2,4 - 2,8 - 2,16
3447 // 3,2 - 3,3 - 3,4 - 3,8 - 3,16
3448 // 4,2 - 4,3 - 4,4 - 4,8 - 4,16
3449 // 5,2 - 5,3 - 5,4 - 5,8 - 5,16
3450 // 6,2 - 6,3 - 6,4 - 6,8 - 6,16
3451 // 7,2 - 7,3 - 7,4 - 7,8 - 7,16
3452 // 8,2 - 8,3 - 8,4 - 8,8 - 8,16
3453 // Load values from LHS matrix
3454 LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
3455
3456 // Load values from RHS matrix
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003457 LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zero);
giuros01b3204e72019-04-01 13:50:22 +01003458
3459 RHS_VFMA_M0xN0(0, a, b0, c);
3460 RHS_VFMA_M0xN0(1, a, b1, c);
3461#if K0 > 2
3462 RHS_VFMA_M0xN0(2, a, b2, c);
3463#endif // K0 > 2
3464#if K0 > 3
3465 RHS_VFMA_M0xN0(3, a, b3, c);
3466#endif // K0 > 3
3467#if K0 > 4
3468 RHS_VFMA_M0xN0(4, a, b4, c);
3469 RHS_VFMA_M0xN0(5, a, b5, c);
3470 RHS_VFMA_M0xN0(6, a, b6, c);
3471 RHS_VFMA_M0xN0(7, a, b7, c);
3472#endif // K0 > 4
3473#if K0 > 8
3474 RHS_VFMA_M0xN0(8, a, b8, c);
3475 RHS_VFMA_M0xN0(9, a, b9, c);
Gian Marco Iodice7b9d7ca2019-09-19 16:37:39 +01003476 RHS_VFMA_M0xN0(A, a, bA, c);
3477 RHS_VFMA_M0xN0(B, a, bB, c);
3478 RHS_VFMA_M0xN0(C, a, bC, c);
3479 RHS_VFMA_M0xN0(D, a, bD, c);
3480 RHS_VFMA_M0xN0(E, a, bE, c);
3481 RHS_VFMA_M0xN0(F, a, bF, c);
giuros01b3204e72019-04-01 13:50:22 +01003482#endif // K0 > 8
3483
3484 lhs_offset += K0 * sizeof(DATA_TYPE);
3485 rhs_offset += K0 * rhs_stride_y;
3486 }
3487
3488 // Left-over accumulations
3489 for(; i < K; ++i)
3490 {
3491 // Load values from LHS matrix
3492 VEC_DATA_TYPE(DATA_TYPE, 2)
3493 a0 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 0 * lhs_stride_y + zlhs0));
3494#if M0 > 1
3495 VEC_DATA_TYPE(DATA_TYPE, 2)
3496 a1 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 1 * lhs_stride_y + zlhs1));
3497#endif // M0 > 1
3498#if M0 > 2
3499 VEC_DATA_TYPE(DATA_TYPE, 2)
3500 a2 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 2 * lhs_stride_y + zlhs2));
3501#endif // M0 > 2
3502#if M0 > 3
3503 VEC_DATA_TYPE(DATA_TYPE, 2)
3504 a3 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 3 * lhs_stride_y + zlhs3));
3505#endif // M0 > 3
3506#if M0 > 4
3507 VEC_DATA_TYPE(DATA_TYPE, 2)
3508 a4 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 4 * lhs_stride_y + zlhs4));
3509#endif // M0 > 4
3510#if M0 > 5
3511 VEC_DATA_TYPE(DATA_TYPE, 2)
3512 a5 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 5 * lhs_stride_y + zlhs5));
3513#endif // M0 > 5
3514#if M0 > 6
3515 VEC_DATA_TYPE(DATA_TYPE, 2)
3516 a6 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 6 * lhs_stride_y + zlhs6));
3517#endif // M0 > 6
3518#if M0 > 7
3519 VEC_DATA_TYPE(DATA_TYPE, 2)
3520 a7 = *((__global DATA_TYPE *)(lhs_ptr + lhs_offset + 7 * lhs_stride_y + zlhs7));
3521#endif // M0 > 7
3522
3523 VEC_DATA_TYPE(DATA_TYPE, N0)
3524 b = VLOAD(N0)(0, (__global DATA_TYPE *)(rhs_ptr + rhs_offset + 0 * rhs_stride_y));
3525 RHS_VFMA_M0xN0(0, a, b, c);
3526
3527 lhs_offset += sizeof(DATA_TYPE);
3528 rhs_offset += rhs_stride_y;
3529 }
3530
3531 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)N0 * sizeof(DATA_TYPE)) + (y * (uint)M0 * dst_stride_y);
3532
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003533 REPEAT_VAR_INIT_TO_CONST(M0, uint, zout, 0);
giuros01b3204e72019-04-01 13:50:22 +01003534
3535#if defined(REINTERPRET_OUTPUT_AS_3D)
3536 // The plane (zout) is calculated dividing M (y * M0) by HEIGHT_GEMM3D
3537 CALCULATE_Z_OFFSET(M0, uint, zout, y, HEIGHT_GEMM3D, DEPTH_GEMM3D, dst_cross_plane_pad, dst_stride_y);
3538
3539 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3540 // multiply dst_stride_z by DEPTH_GEMM3D
3541 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3542
3543#else // defined(REINTERPRET_OUTPUT_AS_3D)
3544
3545 // Add offset for batched GEMM
3546 dst_addr += z * dst_stride_z;
3547
3548#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3549
3550 // Multiply by the weight of matrix-matrix product and store the result
giuros01b3204e72019-04-01 13:50:22 +01003551#if defined(ALPHA)
3552 SCALE_BLOCK(M0, DATA_TYPE, c, ALPHA);
3553#endif // defined(ALPHA)
3554
Gian Marco Iodice944170e2019-06-24 14:40:30 +01003555 // Add beta*bias
3556#if defined(BETA)
3557#if defined(BROADCAST_BIAS)
3558 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));
3559
3560 LOAD_BLOCK(1, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
3561
3562#ifndef UNIT_BETA
3563 SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
3564#endif // UNIT_BIAS
3565
3566 // c = c + bias[broadcasted]
3567 ADD_BLOCK_BROADCAST(M0, c, bias0);
3568
3569#else // defined(BROADCAST_BIAS)
3570 __global uchar *bias_addr = bias_ptr + bias_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (get_global_id(1) * (uint)M0 * bias_stride_y) + get_global_id(
3571 2) * bias_stride_z;
3572
3573 LOAD_BLOCK(M0, N0, DATA_TYPE, bias, bias_addr, 0, bias_stride_y, zero);
3574
3575#ifndef UNIT_BETA
3576 SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
3577#endif // UNIT_BIAS
3578
3579 // c = c + bias
3580 ADD_BLOCK(M0, c, bias);
3581
3582#endif // defined(BROADCAST_BIAS)
3583#endif // defined(BETA)
3584
Gian Marco Iodiceca1f4602019-07-16 15:46:48 +01003585#if defined(ACTIVATION_TYPE)
3586 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, c, A_VAL, B_VAL);
3587#endif // defined(ACTIVATION_TYPE)
3588
giuros01b3204e72019-04-01 13:50:22 +01003589 // Store output block
3590 STORE_BLOCK(M0, N0, DATA_TYPE, c, dst_addr, dst_stride_y, zout);
3591
3592#undef RHS_BLOCK_SIZE
3593#undef RHS_OFFSET_X
3594#undef RHS_STEP_X
3595}
3596#endif // defined(M0) && defined(N0) && defined(K0) && defined(K) && defined(DATA_TYPE)
3597
Gian Marco36a0a462018-01-12 10:21:40 +00003598#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003599/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00003600 *
Gian Marco19835e52018-01-30 13:35:54 +00003601 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003602 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3603 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3604 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3605 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003606 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003607 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3608 * The activation function is performed after the bias addition
3609 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003610 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3611 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3612 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3613 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3614 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003615 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3616 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3617 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3618 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3619 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3620 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003621 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003622 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3623 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3624 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3625 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3626 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
3628 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3629 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3630 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3631 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3632 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003633 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003634 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003635 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003636 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003637 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003638 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003639 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3640 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003641 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003642 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003643 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003644 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003645__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
3646 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003647#if defined(BETA)
3648 IMAGE_DECLARATION(src2),
3649#endif // defined(BETA)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01003650 IMAGE_DECLARATION(dst),
3651 uint src0_stride_z,
3652 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003653#if defined(BETA)
3654 uint src2_stride_z,
3655#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003656 uint dst_stride_z
3657#if defined(REINTERPRET_OUTPUT_AS_3D)
3658 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003659 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003660#endif // REINTERPRET_OUTPUT_AS_3D
3661 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003662{
Gian Marco36a0a462018-01-12 10:21:40 +00003663 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3664 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003665 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003666
Gian Marco36a0a462018-01-12 10:21:40 +00003667 // Offset
3668 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3669 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003670
Gian Marco36a0a462018-01-12 10:21:40 +00003671 // src_addr_a = address of matrix A
3672 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003673 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3674 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3675
3676#if defined(MATRIX_B_DEPTH)
3677 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3678 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3679#else // defined(MATRIX_B_DEPTH)
3680 src1_addr_in_bytes += z * src1_stride_z;
3681#endif // defined(MATRIX_B_DEPTH)
3682
3683 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
3684 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003685
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003686 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00003687 __global float *src_end_addr_b = src_addr_b + COLS_B;
3688
3689 src_addr_a += offset_row_a;
3690 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003691
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003692 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003693 float4 c0 = 0.0f;
3694 float4 c1 = 0.0f;
3695 float4 c2 = 0.0f;
3696 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003697
Gian Marco36a0a462018-01-12 10:21:40 +00003698 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003699 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003700 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003701 float4 a0 = vload4(0, src_addr_a);
3702 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003703
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003704 c0 += (float4)a0.s0 * b0;
3705 c1 += (float4)a0.s1 * b0;
3706 c2 += (float4)a0.s2 * b0;
3707 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003708
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003709 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003710 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
3711 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003712
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003713 c0 += (float4)a0.s0 * b0;
3714 c1 += (float4)a0.s1 * b0;
3715 c2 += (float4)a0.s2 * b0;
3716 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003717 }
3718
Gian Marco36a0a462018-01-12 10:21:40 +00003719 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003720 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003721 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00003722 float4 a0 = vload4(0, src_addr_a);
3723 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003724
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003725 c0 += (float4)a0.s0 * b0;
3726 c1 += (float4)a0.s1 * b0;
3727 c2 += (float4)a0.s2 * b0;
3728 c3 += (float4)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003729 }
3730
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003731 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003732 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
3733
Gian Marcoae2af742018-02-15 12:35:44 +00003734 // Compute dst address
3735 __global uchar *dst_addr = offset(&dst, 0, 0);
3736
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003737 uint4 zout = 0;
3738
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003739#if defined(REINTERPRET_OUTPUT_AS_3D)
3740 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003741 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003742 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003743 // | |
3744 // | plane0 |
3745 // | |
3746 // |__________________|
3747 // |******************|
3748 // | cross_plane_pad |
3749 // |******************|
3750 // | |
3751 // | plane1 |
3752 // | |
3753 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003754
3755 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003756 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
3757 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003758
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003759 // Add offset due to the cross plane paddings
3760 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003761
3762 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3763 // multiply dst_stride_z by DEPTH_GEMM3D
3764 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003765#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00003766 // Add offset for batched GEMM
3767 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003768#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3769
3770 // Multiply by the weight of matrix-matrix product and store the result
3771#if defined(ALPHA)
3772 SCALE_BLOCK(4, float, c, ALPHA);
3773#endif // defined(ALPHA)
3774
3775 // Add beta*bias
3776#if defined(BETA)
3777 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
3778
3779#if defined(BROADCAST_BIAS)
3780 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
3781
3782 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3783
3784#ifndef UNIT_BETA
3785 SCALE_BLOCK(1, float, bias, BETA);
3786#endif // UNIT_BIAS
3787
3788 // c = c + bias[broadcasted]
3789 ADD_BLOCK_BROADCAST(4, c, bias0);
3790
3791#else // defined(BROADCAST_BIAS)
3792 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
3793 2) * src2_stride_z;
3794
3795 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
3796
3797#ifndef UNIT_BETA
3798 SCALE_BLOCK(4, float, bias, BETA);
3799#endif // UNIT_BIAS
3800
3801 // c = c + bias
3802 ADD_BLOCK(4, c, bias);
3803
3804#endif // defined(BROADCAST_BIAS)
3805#endif // defined(BETA)
3806
3807#if defined(ACTIVATION_TYPE)
3808 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
3809#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00003810
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003811 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003812 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
3813 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
3814 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
3815 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003816}
3817
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003818/** This OpenCL kernel is optimized for Bifrost and tt computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003819 *
Gian Marco19835e52018-01-30 13:35:54 +00003820 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003821 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
3822 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3823 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
3824 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
3825 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003826 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003827 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
3828 * The activation function is performed after the bias addition
3829 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003830 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
3831 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
3832 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
3833 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
3834 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003835 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
3836 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
3837 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3838 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
3839 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3840 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003841 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003842 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
3843 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3844 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
3845 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3846 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003847 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
3848 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
3849 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
3850 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
3851 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
3852 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01003853 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003854 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003855 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003856 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00003857 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003858 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003859 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
3860 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003861 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003862 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01003863 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003864 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003865__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
3866 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003867#if defined(BETA)
3868 IMAGE_DECLARATION(src2),
3869#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00003870 IMAGE_DECLARATION(dst),
3871 uint src0_stride_z,
3872 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003873#if defined(BETA)
3874 uint src2_stride_z,
3875#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003876 uint dst_stride_z
3877#if defined(REINTERPRET_OUTPUT_AS_3D)
3878 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01003879 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00003880#endif // REINTERPRET_OUTPUT_AS_3D
3881 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003882{
Gian Marco36a0a462018-01-12 10:21:40 +00003883 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
3884 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00003885 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +00003886
3887 // Offset
3888 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
3889 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
3890
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003891 // src_addr_a = address of matrix A
3892 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00003893 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
3894 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
3895
3896#if defined(MATRIX_B_DEPTH)
3897 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3898 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
3899#else // defined(MATRIX_B_DEPTH)
3900 src1_addr_in_bytes += z * src1_stride_z;
3901#endif // defined(MATRIX_B_DEPTH)
3902
3903 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
3904 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003905
Gian Marco36a0a462018-01-12 10:21:40 +00003906 src_addr_a += offset_row_a;
3907 src_addr_b += offset_row_b;
3908
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003909 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003910 float4 c0 = 0.0f;
3911 float4 c1 = 0.0f;
3912 float4 c2 = 0.0f;
3913 float4 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003914
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003915#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
3916
3917 int i = 0;
3918 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003919 {
3920 // Load values from matrix A (interleaved) and matrix B (transposed)
3921 float4 a0 = vload4(0, src_addr_a);
3922 float4 b0 = vload4(0, src_addr_b);
3923
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003924 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3925 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003926
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003927 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3928 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3929 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3930 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003931
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003932 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3933 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3934 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3935 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003936
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003937 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3938 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3939 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3940 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003941
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003942 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3943 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3944 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3945 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003946
3947 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003948 a0 = vload4(0, src_addr_a);
3949 b0 = vload4(0, src_addr_b);
3950
3951 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3952 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003953
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003954 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3955 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3956 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3957 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003958
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003959 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3960 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3961 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3962 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003963
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003964 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3965 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3966 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3967 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003968
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003969 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3970 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3971 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3972 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003973
3974 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003975 a0 = vload4(0, src_addr_a);
3976 b0 = vload4(0, src_addr_b);
3977
3978 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
3979 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
3980
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003981 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
3982 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
3983 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
3984 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003985
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003986 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
3987 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
3988 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
3989 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003990
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003991 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
3992 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
3993 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
3994 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01003995
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01003996 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
3997 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
3998 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
3999 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004000
4001 // Load values from matrix A (interleaved) and matrix B (transposed)
4002 a0 = vload4(0, src_addr_a);
4003 b0 = vload4(0, src_addr_b);
4004
4005 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4006 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004007
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004008 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
4009 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
4010 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
4011 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004012
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004013 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
4014 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
4015 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
4016 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004017
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004018 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
4019 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
4020 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
4021 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004022
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004023 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
4024 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
4025 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
4026 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004027 }
4028
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004029 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004030 {
4031 // Load values from matrix A (interleaved) and matrix B (transposed)
4032 float4 a0 = vload4(0, src_addr_a);
4033 float4 b0 = vload4(0, src_addr_b);
4034
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01004035 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
4036 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
4037
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004038 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
4039 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
4040 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
4041 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004042
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004043 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
4044 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
4045 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
4046 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004047
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004048 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
4049 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
4050 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
4051 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004052
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004053 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
4054 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
4055 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
4056 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004057 }
4058
4059 // Compute destination address
4060 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4061
Gian Marcoae2af742018-02-15 12:35:44 +00004062 // Compute dst address
4063 __global uchar *dst_addr = offset(&dst, 0, 0);
4064
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004065 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004066
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004067#if defined(REINTERPRET_OUTPUT_AS_3D)
4068 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004069 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004070 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004071 // | |
4072 // | plane0 |
4073 // | |
4074 // |__________________|
4075 // |******************|
4076 // | cross_plane_pad |
4077 // |******************|
4078 // | |
4079 // | plane1 |
4080 // | |
4081 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004082
4083 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004084 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4085 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004086
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004087 // Add offset due to the cross plane paddings
4088 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004089
4090 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4091 // multiply dst_stride_z by DEPTH_GEMM3D
4092 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004093#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00004094 // Add offset for batched GEMM
4095 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004096#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4097
4098 // Multiply by the weight of matrix-matrix product and store the result
4099#if defined(ALPHA)
4100 SCALE_BLOCK(4, float, c, ALPHA);
4101#endif // defined(ALPHA)
4102
4103 // Add beta*bias
4104#if defined(BETA)
4105 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4106
4107#if defined(BROADCAST_BIAS)
4108 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
4109
4110 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4111
4112#ifndef UNIT_BETA
4113 SCALE_BLOCK(1, float, bias, BETA);
4114#endif // UNIT_BIAS
4115
4116 // c = c + bias[broadcasted]
4117 ADD_BLOCK_BROADCAST(4, c, bias0);
4118
4119#else // defined(BROADCAST_BIAS)
4120 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4121 2) * src2_stride_z;
4122
4123 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
4124
4125#ifndef UNIT_BETA
4126 SCALE_BLOCK(4, float, bias, BETA);
4127#endif // UNIT_BIAS
4128
4129 // c = c + bias
4130 ADD_BLOCK(4, c, bias);
4131
4132#endif // defined(BROADCAST_BIAS)
4133#endif // defined(BETA)
4134
4135#if defined(ACTIVATION_TYPE)
4136 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, c, A_VAL, B_VAL);
4137#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00004138
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004139 // Store 4x4 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004140 vstore4(c0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
4141 vstore4(c1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
4142 vstore4(c2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
4143 vstore4(c3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004144}
4145
Georgios Pinitas84225582018-05-14 12:00:05 +01004146// Undefine local defines
4147#undef COLS_MTX_B
4148
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004149#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004150/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00004151 *
Gian Marco19835e52018-01-30 13:35:54 +00004152 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004153 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
4154 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
4155 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
4156 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004157 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004158 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
4159 * The activation function is performed after the bias addition
4160 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004161 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
4162 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
4163 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
4164 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
4165 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004166 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
4167 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
4168 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4169 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
4170 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4171 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004172 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004173 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
4174 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
4175 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
4176 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
4177 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004178 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
4179 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
4180 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
4181 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
4182 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
4183 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004184 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004185 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00004186 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004187 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +00004188 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004189 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004190 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
4191 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004192 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004193 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01004194 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004195 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004196__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
4197 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004198#if defined(BETA)
4199 IMAGE_DECLARATION(src2),
4200#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00004201 IMAGE_DECLARATION(dst),
4202 uint src0_stride_z,
4203 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004204#if defined(BETA)
4205 uint src2_stride_z,
4206#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004207 uint dst_stride_z
4208#if defined(REINTERPRET_OUTPUT_AS_3D)
4209 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004210 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004211#endif // REINTERPRET_OUTPUT_AS_3D
4212 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004213{
Gian Marco36a0a462018-01-12 10:21:40 +00004214 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
4215 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +00004216 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004217
Gian Marco36a0a462018-01-12 10:21:40 +00004218 // Offset
4219 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
4220 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004221
Gian Marco36a0a462018-01-12 10:21:40 +00004222 // src_addr_a = address of matrix A
4223 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00004224 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
4225 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
4226
4227#if defined(MATRIX_B_DEPTH)
4228 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
4229 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
4230#else // defined(MATRIX_B_DEPTH)
4231 src1_addr_in_bytes += z * src1_stride_z;
4232#endif // defined(MATRIX_B_DEPTH)
4233
4234 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
4235 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004236
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004237 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00004238 __global half *src_end_addr_b = src_addr_b + COLS_B;
4239
4240 src_addr_a += offset_row_a;
4241 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004242
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004243 // Reset accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004244 half8 c0 = 0.0f;
4245 half8 c1 = 0.0f;
4246 half8 c2 = 0.0f;
4247 half8 c3 = 0.0f;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004248
Gian Marco36a0a462018-01-12 10:21:40 +00004249 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004250 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004251 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004252 half4 a0 = vload4(0, src_addr_a);
4253 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004254
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004255 c0 += (half8)a0.s0 * b0;
4256 c1 += (half8)a0.s1 * b0;
4257 c2 += (half8)a0.s2 * b0;
4258 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004259
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004260 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004261 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
4262 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004263
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004264 c0 += (half8)a0.s0 * b0;
4265 c1 += (half8)a0.s1 * b0;
4266 c2 += (half8)a0.s2 * b0;
4267 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004268 }
4269
Gian Marco36a0a462018-01-12 10:21:40 +00004270 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004271 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004272 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00004273 half4 a0 = vload4(0, src_addr_a);
4274 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004275
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004276 c0 += (half8)a0.s0 * b0;
4277 c1 += (half8)a0.s1 * b0;
4278 c2 += (half8)a0.s2 * b0;
4279 c3 += (half8)a0.s3 * b0;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004280 }
4281
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004282 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004283 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
4284
Gian Marcoae2af742018-02-15 12:35:44 +00004285 // Compute dst address
4286 __global uchar *dst_addr = offset(&dst, 0, 0);
4287
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004288 uint4 zout = 0;
4289
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004290#if defined(REINTERPRET_OUTPUT_AS_3D)
4291 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004292 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004293 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004294 // | |
4295 // | plane0 |
4296 // | |
4297 // |__________________|
4298 // |******************|
4299 // | cross_plane_pad |
4300 // |******************|
4301 // | |
4302 // | plane1 |
4303 // | |
4304 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004305
4306 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004307 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
4308 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004309
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01004310 // Add offset due to the cross plane paddings
4311 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004312
4313 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
4314 // multiply dst_stride_z by DEPTH_GEMM3D
4315 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00004316#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00004317 // Add offset for batched GEMM
4318 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004319#endif // defined(REINTERPRET_OUTPUT_AS_3D)
4320
4321 // Multiply by the weight of matrix-matrix product and store the result
4322#if defined(ALPHA)
4323 SCALE_BLOCK(4, half, c, ALPHA);
4324#endif // defined(ALPHA)
4325
4326 // Add beta*bias
4327#if defined(BETA)
4328 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
4329
4330#if defined(BROADCAST_BIAS)
4331 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
4332
4333 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4334
4335#ifndef UNIT_BETA
4336 SCALE_BLOCK(1, half, bias, BETA);
4337#endif // UNIT_BIAS
4338
4339 // c = c + bias[broadcasted]
4340 ADD_BLOCK_BROADCAST(4, c, bias0);
4341
4342#else // defined(BROADCAST_BIAS)
4343
4344 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
4345 2) * src2_stride_z;
4346
4347 LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
4348
4349#ifndef UNIT_BETA
4350 SCALE_BLOCK(4, half, bias, BETA);
4351#endif // UNIT_BIAS
4352
4353 // c = c + bias
4354 ADD_BLOCK(4, c, bias);
4355
4356#endif // defined(BROADCAST_BIAS)
4357#endif // defined(BETA)
4358
4359#if defined(ACTIVATION_TYPE)
4360 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
4361#endif // defined(ACTIVATION_TYPE)
Gian Marcoae2af742018-02-15 12:35:44 +00004362
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00004363 // Store 4x8 block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01004364 vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
4365 vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
4366 vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
4367 vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004368}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01004369
/** OpenCL kernel computing the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1).
 *  The F16 inputs are converted to F32 and the products are accumulated in 32-bit floating point variables.
 *  The accumulators are converted back to F16 before the activation (if any) and the store.
 *
 * @note Each work-item computes and stores a block of 4 rows by 8 columns of the destination matrix
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) is only used if -DBETA is passed at compile time. If -DBROADCAST_BIAS is passed, a single bias row is loaded and added to the 4 rows of the block.
 *       If -DUNIT_BETA is passed, the bias is not multiplied by BETA.
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition, on the values already converted to F16
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif //defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                      )
{
    // Indices of the reshaped row of B, the reshaped row of A and the batch handled by this work-item
    const int rhs_row = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    const int lhs_row = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    const int batch   = get_global_id(2);

    // Element offsets of this work-item's sub-block inside the reshaped rows
    const int lhs_sub_offset = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int rhs_sub_offset = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // Byte offsets of the reshaped rows inside matrix A and matrix B
    int lhs_bytes = batch * src0_stride_z + lhs_row * src0_stride_y + src0_offset_first_element_in_bytes;
    int rhs_bytes = rhs_row * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    rhs_bytes += (batch % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_bytes += batch * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // The end address of the row of matrix B is taken before the sub-block offset is applied
    __global half *const rhs_row_start = (__global half *)(src1_ptr + rhs_bytes);
    __global half *const rhs_row_end   = rhs_row_start + COLS_B;

    __global half *lhs_it = (__global half *)(src0_ptr + lhs_bytes) + lhs_sub_offset;
    __global half *rhs_it = rhs_row_start + rhs_sub_offset;

    // F32 accumulators (one per output row). The basename "c" is used by the block macros below
    float8 c0 = 0.0f;
    float8 c1 = 0.0f;
    float8 c2 = 0.0f;
    float8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration
    while(rhs_it <= (rhs_row_end - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)))
    {
        // First step: load values from matrix A (interleaved) and matrix B (transposed)
        float4 lhs = convert_float4(vload4(0, lhs_it));
        float8 rhs = convert_float8(vload8(0, rhs_it));

        c0 += (float8)lhs.s0 * rhs;
        c1 += (float8)lhs.s1 * rhs;
        c2 += (float8)lhs.s2 * rhs;
        c3 += (float8)lhs.s3 * rhs;

        // Second step: load the next values from matrix A (interleaved) and matrix B (transposed)
        lhs = convert_float4(vload4(0, lhs_it + 4 * MULT_INTERLEAVE4X4_HEIGHT));
        rhs = convert_float8(vload8(0, rhs_it + 8 * MULT_TRANSPOSE1XW_WIDTH));

        c0 += (float8)lhs.s0 * rhs;
        c1 += (float8)lhs.s1 * rhs;
        c2 += (float8)lhs.s2 * rhs;
        c3 += (float8)lhs.s3 * rhs;

        lhs_it += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        rhs_it += 16 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Left-over loop: one accumulation step per iteration
    while(rhs_it < rhs_row_end)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        const float4 lhs = convert_float4(vload4(0, lhs_it));
        const float8 rhs = convert_float8(vload8(0, rhs_it));

        c0 += (float8)lhs.s0 * rhs;
        c1 += (float8)lhs.s1 * rhs;
        c2 += (float8)lhs.s2 * rhs;
        c3 += (float8)lhs.s3 * rhs;

        lhs_it += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        rhs_it += 8 * MULT_TRANSPOSE1XW_WIDTH;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Additional byte offset of each of the 4 output rows (non-zero only when the output is reinterpreted as 3D)
    uint4 plane_offset = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane index of each row is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    plane_offset = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    plane_offset = min(DEPTH_GEMM3D - 1, plane_offset);

    // Add offset due to the cross plane paddings
    plane_offset *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *bias_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, bias_addr, 0, src2_stride_y, zero);

    // The bias is loaded as F16 and converted to F32 to match the accumulators
    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else  // defined(BROADCAST_BIAS)
    __global uchar *bias_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, bias_addr, 0, src2_stride_y, zero);

    // The bias is loaded as F16 and converted to F32 to match the accumulators
    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Convert the F32 accumulators to the F16 output type
    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c_h0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + plane_offset.s0));
    vstore8(c_h1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + plane_offset.s1));
    vstore8(c_h2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + plane_offset.s2));
    vstore8(c_h3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + plane_offset.s3));
}
4599
/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
 *
 * @note Each work-item computes and stores a block of 4 rows by 8 columns of the destination matrix, accumulating in F16 with fma()
 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (e.g. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (e.g. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 * @note The bias matrix (src2) is only used if -DBETA is passed at compile time. If -DBROADCAST_BIAS is passed, a single bias row is loaded and added to the 4 rows of the block.
 *       If -DUNIT_BETA is passed, the bias is not multiplied by BETA.
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                        )
{
    int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
    int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
    int z = get_global_id(2);

    // Offset
    const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
    const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Note: no end address for matrix B is needed here, the loops below are bounded by the counter i against COLS_MTX_B
    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Number of accumulation steps: each step consumes 8 * MULT_TRANSPOSE1XW_WIDTH elements of the reshaped row of matrix B
#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))

    // Main loop: 4 accumulation steps per iteration
    int i = 0;
    for(; i <= (int)(COLS_MTX_B - 4); i += 4)
    {
#if MULT_INTERLEAVE4X4_HEIGHT == 1
        // With MULT_INTERLEAVE4X4_HEIGHT == 1 two consecutive groups of 4 values of matrix A are contiguous,
        // so a single vload8 provides the values of matrix A for two accumulation steps

        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);
#else  // MULT_INTERLEAVE4X4_HEIGHT == 1
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
    }

    // Left-over loop: one accumulation step per iteration
    for(; i < (int)(COLS_MTX_B); ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
        src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Additional byte offset of each of the 4 output rows (non-zero only when the output is reinterpreted as 3D)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else  // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block
    vstore8(c0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
    vstore8(c1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
    vstore8(c2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
    vstore8(c3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
}
Georgios Pinitas84225582018-05-14 12:00:05 +01004899
4900// Undefine local defines
4901#undef COLS_MTX_B
4902
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01004903#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01004904
Gian Marco36a0a462018-01-12 10:21:40 +00004905#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01004906
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01004907#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * Each work-item produces a tile of NUM_ELEMS_PROCESSED_PER_THREAD_Y rows by NUM_ELEMS_PROCESSED_PER_THREAD_X columns of the destination.
 * The kernel body provides accumulators for at most 4 rows (acc0..acc3).
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note The optional bias (src2) is enabled with -DBETA. It is scaled by BETA unless -DUNIT_BETA is defined, and a single bias row is added to every output row when -DBROADCAST_BIAS is defined.
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of the matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom padding of each plane of the input tensor, expressed in rows (it is multiplied by src0_stride_y) (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom padding of each plane of the output tensor, expressed in rows (it is multiplied by dst_stride_y) (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif // defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    // First destination column handled by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (s0: byte offset into src0, s1: byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Vector operand first and an explicit vector bound, so the call matches min(gentype, gentype)
    zin = min(zin, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings: turn the plane index into a byte offset
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the first matrix A row handled by this work-item
    int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));

    VECTOR_TYPE acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume two elements of each matrix A row (and two rows of matrix B) per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Left-over loop: consume the remaining element of each matrix A row when COLS_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row byte offset due to cross plane paddings; stays zero unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    // Clamp to the last plane. Vector operand first and an explicit vector bound, so the call matches min(gentype, gentype)
    zout = min(zout, (uint4)(DEPTH_GEMM3D - 1));

    // Add offset due to the cross plane paddings: turn the plane index into a byte offset
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    // The bias is never reinterpreted as 3D here: it is loaded with zero cross-plane offsets
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)NUM_ELEMS_PROCESSED_PER_THREAD_X * sizeof(DATA_TYPE)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, DATA_TYPE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, NUM_ELEMS_PROCESSED_PER_THREAD_X, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s);
}
#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01005221
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01005222/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005223 *
5224 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
5225 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5226 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
5227 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5228 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005229 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5230 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005231 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005232 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
5233 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005234 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5235 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005236 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5237 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5238 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5239 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5240 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005241 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005242 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5243 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5244 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5245 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5246 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5247 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5248 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5249 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
5250 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5251 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
5252 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5254 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5255 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5256 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5257 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5258 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005259 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
5260 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5261 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
5262 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5263 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
5264 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005265 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5266 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005267 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005268 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005269 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5270 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005271 */
5272__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
5273 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005274#if defined(BETA)
5275 IMAGE_DECLARATION(src2),
5276#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00005277 IMAGE_DECLARATION(dst),
5278 uint src0_stride_z,
5279 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005280#if defined(BETA)
5281 uint src2_stride_z,
5282#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005283 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005284#if defined(REINTERPRET_INPUT_AS_3D)
5285 ,
5286 uint src_cross_plane_pad
5287#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005288#if defined(REINTERPRET_OUTPUT_AS_3D)
5289 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005290 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005291#endif // REINTERPRET_OUTPUT_AS_3D
5292 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005293{
5294 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5295
5296 // Compute starting address for matrix A and matrix B
5297 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5298
5299 // Update address for matrix A
5300 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5301
5302 // Update address for matrix B
5303 src_addr.s1 += idx * sizeof(float);
5304
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005305#if defined(REINTERPRET_INPUT_AS_3D)
5306 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5307 // in order to take into account the presence of possible cross plane paddings
5308 //
5309 // | |
5310 // | plane0 |
5311 // | |
5312 // |__________________|
5313 // |******************|
5314 // | cross_plane_pad |
5315 // |******************|
5316 // | |
5317 // | plane1 |
5318 // | |
5319 // |__________________|
5320
5321 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5322 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5323 zin = min(DEPTH_GEMM3D - 1, zin);
5324
5325 // Add offset due to the cross plane paddings
5326 zin *= (src_cross_plane_pad * src0_stride_y);
5327
5328 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5329 // multiply src0_stride_z by DEPTH_GEMM3D
5330 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5331
5332#else // defined(REINTERPRET_INPUT_AS_3D)
5333
Gian Marcoae2af742018-02-15 12:35:44 +00005334 // Add offset for batched GEMM
5335 src_addr.s0 += get_global_id(2) * src0_stride_z;
5336
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005337#endif // defined(REINTERPRET_INPUT_AS_3D)
5338
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005339#if defined(MATRIX_B_DEPTH)
5340 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5341 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5342#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005343 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005344#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005345
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005346 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005347 float4 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005348
5349#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005350 float4 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005351#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5352
5353#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005354 float4 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005355#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5356
5357#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005358 float4 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005359#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5360
5361 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005362 int i = 0;
5363 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005364 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005365#if defined(REINTERPRET_INPUT_AS_3D)
5366 // Load values from matrix A and matrix B
Usama Arif0681e3b2019-04-25 14:28:07 +01005367 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
5368#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005369 // Load values from matrix A and matrix B
5370 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005371#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005372 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005373#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5374#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005375 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005376#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5377#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005378 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005379#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005380#endif // defined(REINTERPRET_INPUT_AS_3D)
5381
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005382 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5383 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005384
5385 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005386 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5387 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5388 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
5389 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005390
5391#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005392
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005393 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
5394 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
5395 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
5396 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005397
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005398#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5399#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005400
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005401 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
5402 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
5403 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
5404 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005405
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005406#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5407#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005408
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005409 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
5410 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
5411 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
5412 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005413#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005414
5415 // Load values from matrix A and matrix B
5416 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5417 src_addr.s1 += src1_stride_y;
5418
5419 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005420 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
5421 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
5422 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
5423 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005424
5425#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5426
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005427 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
5428 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
5429 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
5430 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005431
5432#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5434
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005435 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
5436 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
5437 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
5438 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005439
5440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5442
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005443 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
5444 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
5445 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
5446 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005447#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5448
5449 // Load values from matrix A and matrix B
5450 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5451 src_addr.s1 += src1_stride_y;
5452
5453 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005454 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
5455 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
5456 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
5457 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005458
5459#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5460
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005461 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
5462 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
5463 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
5464 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005465
5466#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5467#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5468
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005469 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
5470 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
5471 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
5472 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005473
5474#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5475#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5476
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005477 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
5478 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
5479 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
5480 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005481#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5482
5483 // Load values from matrix A and matrix B
5484 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
5485 src_addr.s1 += src1_stride_y;
5486
5487 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005488 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
5489 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
5490 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
5491 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005492
5493#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5494
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005495 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
5496 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
5497 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
5498 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005499
5500#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5501#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5502
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005503 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
5504 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
5505 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
5506 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005507
5508#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5509#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5510
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005511 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
5512 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
5513 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
5514 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005515#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5516
5517 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005518 }
5519
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005520 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005521 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005522#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005523 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005524 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5525#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5526 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5527#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5528#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5529 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5530#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5531#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5532 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5533#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5534#else // defined(REINTERPRET_INPUT_AS_3D)
5535 // Load values from matrix A
5536 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005537#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5538 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5539#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5540#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5541 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5542#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5543#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5544 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5545#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005546#endif // defined(REINTERPRET_INPUT_AS_3D)
5547
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005548 // Load values from matrix B
5549 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005550 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005551
5552 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005553 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5554 acc0.s1 = fma(a0, b0.s1, acc0.s1);
5555 acc0.s2 = fma(a0, b0.s2, acc0.s2);
5556 acc0.s3 = fma(a0, b0.s3, acc0.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005557#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005558 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5559 acc1.s1 = fma(a1, b0.s1, acc1.s1);
5560 acc1.s2 = fma(a1, b0.s2, acc1.s2);
5561 acc1.s3 = fma(a1, b0.s3, acc1.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005562#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5563#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005564 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5565 acc2.s1 = fma(a2, b0.s1, acc2.s1);
5566 acc2.s2 = fma(a2, b0.s2, acc2.s2);
5567 acc2.s3 = fma(a2, b0.s3, acc2.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005568#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5569#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005570 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5571 acc3.s1 = fma(a3, b0.s1, acc3.s1);
5572 acc3.s2 = fma(a3, b0.s2, acc3.s2);
5573 acc3.s3 = fma(a3, b0.s3, acc3.s3);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005574#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005575
5576 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005577 }
5578
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005579 int z = get_global_id(2);
5580
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005581 // Compute destination address
5582 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5583
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005584 // Compute dst address
5585 __global uchar *dst_addr = offset(&dst, 0, 0);
5586
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005587 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005588
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005589#if defined(REINTERPRET_OUTPUT_AS_3D)
5590 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005591 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005592 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005593 // | |
5594 // | plane0 |
5595 // | |
5596 // |__________________|
5597 // |******************|
5598 // | cross_plane_pad |
5599 // |******************|
5600 // | |
5601 // | plane1 |
5602 // | |
5603 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005604
5605 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005606 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5607 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005608
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01005609 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005610 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005611
5612 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5613 // multiply dst_stride_z by DEPTH_GEMM3D
5614 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005615#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005616 // Add offset for batched GEMM
5617 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005618#endif // defined(REINTERPRET_OUTPUT_AS_3D)
5619
5620 // Multiply by the weight of matrix-matrix product and store the result
5621#if defined(ALPHA)
5622 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
5623#endif // defined(ALPHA)
5624
5625 // Add beta*bias
5626#if defined(BETA)
5627 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
5628
5629#if defined(BROADCAST_BIAS)
5630 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
5631
5632 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
5633
5634#ifndef UNIT_BETA
5635 SCALE_BLOCK(1, float, bias, BETA);
5636#endif // UNIT_BIAS
5637
5638 // acc = acc + bias[broadcasted]
5639 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
5640
5641#else // defined(BROADCAST_BIAS)
5642 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) *
5643 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
5644
5645 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
5646
5647#ifndef UNIT_BETA
5648 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
5649#endif // UNIT_BIAS
5650
5651 // acc = acc + bias
5652 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
5653
5654#endif // defined(BROADCAST_BIAS)
5655#endif // defined(BETA)
5656
5657#if defined(ACTIVATION_TYPE)
5658 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
5659#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005660
5661 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005662 vstore4(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005663#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005664 vstore4(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005665#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5666#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005667 vstore4(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005668#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5669#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005670 vstore4(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005671#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005672}
5673
5674/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
5675 *
5676 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
5677 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less than or equal to 1000.
5678 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
5679 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
5680 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
5681 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005682 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
5683 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005684 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005685 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions should be passed at compile time as well, using -DA_VAL= and -DB_VAL= respectively.
5686 * The activation function is performed after the bias addition
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005687 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
5688 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005689 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
5690 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
5691 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
5692 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
5693 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005694 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005695 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
5696 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem(in bytes)
5697 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
5698 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem(in bytes)
5699 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
5700 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
5701 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
5702 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem(in bytes)
5703 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
5704 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem(in bytes)
5705 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005706 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
5707 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
5708 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
5709 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
5710 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
5711 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005712 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
5713 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
5714 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
5715 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
5716 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
5717 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005718 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
5719 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005720 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005721 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005722 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
5723 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005724 */
5725__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
5726 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005727#if defined(BETA)
5728 IMAGE_DECLARATION(src2),
5729#endif // defined(BETA)
Gian Marcoae2af742018-02-15 12:35:44 +00005730 IMAGE_DECLARATION(dst),
5731 uint src0_stride_z,
5732 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005733#if defined(BETA)
5734 uint src2_stride_z,
5735#endif //defined(BETA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005736 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005737#if defined(REINTERPRET_INPUT_AS_3D)
5738 ,
5739 uint src_cross_plane_pad
5740#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005741#if defined(REINTERPRET_OUTPUT_AS_3D)
5742 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005743 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005744#endif // REINTERPRET_OUTPUT_AS_3D
5745 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005746{
5747 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5748 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
5749
5750 // Compute starting address for matrix A and Matrix B
5751 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
5752
5753 // Update address for the matrix A
5754 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
5755
5756 // Update address for the matrix B
5757 src_addr.s1 += idx * sizeof(float);
5758
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005759#if defined(REINTERPRET_INPUT_AS_3D)
5760 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
5761 // in order to take into account the presence of possible cross plane paddings
5762 //
5763 // | |
5764 // | plane0 |
5765 // | |
5766 // |__________________|
5767 // |******************|
5768 // | cross_plane_pad |
5769 // |******************|
5770 // | |
5771 // | plane1 |
5772 // | |
5773 // |__________________|
5774
5775 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
5776 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
5777 zin = min(DEPTH_GEMM3D - 1, zin);
5778
5779 // Add offset due to the cross plane paddings
5780 zin *= (src_cross_plane_pad * src0_stride_y);
5781
5782 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
5783 // multiply src0_stride_z by DEPTH_GEMM3D
5784 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
5785
5786#else // defined(REINTERPRET_INPUT_AS_3D)
5787
Gian Marcoae2af742018-02-15 12:35:44 +00005788 // Add offset for batched GEMM
5789 src_addr.s0 += get_global_id(2) * src0_stride_z;
5790
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005791#endif // defined(REINTERPRET_INPUT_AS_3D)
5792
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005793#if defined(MATRIX_B_DEPTH)
5794 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
5795 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
5796#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005797 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00005798#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00005799
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005800 // Initialize accumulators
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005801 float2 acc0 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005802#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005803 float2 acc1 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005804#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5805#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005806 float2 acc2 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005807#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5808#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005809 float2 acc3 = 0.0f;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005810#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5811
5812 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005813 int i = 0;
5814 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005815 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005816#if defined(REINTERPRET_INPUT_AS_3D)
5817 // Load values from matrix A
5818 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
5819#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005820 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005821 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005822#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005823
5824 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005825 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5826 src_addr.s1 += src1_stride_y;
5827 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5828 src_addr.s1 += src1_stride_y;
5829 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5830 src_addr.s1 += src1_stride_y;
5831 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5832 src_addr.s1 += src1_stride_y;
5833 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5834 src_addr.s1 += src1_stride_y;
5835 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5836 src_addr.s1 += src1_stride_y;
5837 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5838 src_addr.s1 += src1_stride_y;
5839 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
5840 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005841
5842 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005843 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
5844 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
5845 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
5846 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
5847 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
5848 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
5849 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
5850 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005851
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005852 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
5853 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
5854 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
5855 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
5856 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
5857 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
5858 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
5859 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005860
5861#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005862#if defined(REINTERPRET_INPUT_AS_3D)
5863 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5864#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005865 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005866#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005867 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
5868 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
5869 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
5870 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
5871 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
5872 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
5873 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
5874 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005875
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005876 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
5877 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
5878 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
5879 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
5880 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
5881 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
5882 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
5883 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005884#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5885#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005886#if defined(REINTERPRET_INPUT_AS_3D)
5887 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5888#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005889 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005890#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005891 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
5892 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
5893 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
5894 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
5895 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
5896 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
5897 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
5898 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005899
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005900 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
5901 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
5902 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
5903 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
5904 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
5905 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
5906 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
5907 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005908#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5909#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005910#if defined(REINTERPRET_INPUT_AS_3D)
5911 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5912#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005913 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005914#endif // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005915 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
5916 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
5917 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
5918 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
5919 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
5920 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
5921 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
5922 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005923
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005924 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
5925 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
5926 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
5927 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
5928 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
5929 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
5930 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
5931 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005932#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005933
5934 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005935 }
5936 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005937 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005938 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005939#if defined(REINTERPRET_INPUT_AS_3D)
5940 // Load values from matrix A
5941 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
5942#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5943 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
5944#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5945#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5946 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
5947#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5948#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5949 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
5950#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5951#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005952 // Load values from matrix A
5953 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
5954#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5955 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
5956#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5957#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5958 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
5959#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5960#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
5961 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
5962#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01005963#endif // defined(REINTERPRET_INPUT_AS_3D)
5964
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005965 // Load values from matrix B
5966 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005967 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005968
5969 // Multiply and accumulate
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005970 acc0.s0 = fma(a0, b0.s0, acc0.s0);
5971 acc0.s1 = fma(a0, b0.s1, acc0.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005972#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005973 acc1.s0 = fma(a1, b0.s0, acc1.s0);
5974 acc1.s1 = fma(a1, b0.s1, acc1.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005975#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
5976#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005977 acc2.s0 = fma(a2, b0.s0, acc2.s0);
5978 acc2.s1 = fma(a2, b0.s1, acc2.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005979#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
5980#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005981 acc3.s0 = fma(a3, b0.s0, acc3.s0);
5982 acc3.s1 = fma(a3, b0.s1, acc3.s1);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005983#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01005984
5985 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005986 }
5987
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005988 int z = get_global_id(2);
5989
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00005990 // Compute destination address
5991 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
5992
Gian Marcoae2af742018-02-15 12:35:44 +00005993 // Compute dst address
5994 __global uchar *dst_addr = offset(&dst, 0, 0);
5995
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005996 uint4 zout = 0;
Michele Di Giorgioebc3a902018-11-16 16:04:25 +00005997
Isabella Gottardi8e74f442018-03-01 16:42:00 +00005998#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01005999
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006000 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006001 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006002 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006003 // | |
6004 // | plane0 |
6005 // | |
6006 // |__________________|
6007 // |******************|
6008 // | cross_plane_pad |
6009 // |******************|
6010 // | |
6011 // | plane1 |
6012 // | |
6013 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00006014
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006015 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006016 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6017 zout = min(DEPTH_GEMM3D - 1, zout);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006018
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01006019 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01006020 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006021
6022 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6023 // multiply dst_stride_z by DEPTH_GEMM3D
6024 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006025#else // defined(REINTERPRET_OUTPUT_AS_3D)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006026 // Add offset for batched GEMM
6027 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006028#endif // defined(REINTERPRET_OUTPUT_AS_3D)
6029
6030 // Multiply by the weight of matrix-matrix product and store the result
6031#if defined(ALPHA)
6032 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
6033#endif // defined(ALPHA)
6034
6035 // Add beta*bias
6036#if defined(BETA)
6037 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
6038
6039#if defined(BROADCAST_BIAS)
6040 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
6041
6042 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
6043
6044#ifndef UNIT_BETA
6045 SCALE_BLOCK(1, float, bias, BETA);
6046#endif // UNIT_BIAS
6047
6048 // acc = acc + bias[broadcasted]
6049 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);
6050
6051#else // defined(BROADCAST_BIAS)
6052 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (get_global_id(1) *
6053 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
6054
6055 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
6056
6057#ifndef UNIT_BETA
6058 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias, BETA);
6059#endif // UNIT_BIAS
6060
6061 // acc = acc + bias
6062 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);
6063
6064#endif // defined(BROADCAST_BIAS)
6065#endif // defined(BETA)
6066
6067#if defined(ACTIVATION_TYPE)
6068 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, float, acc, A_VAL, B_VAL);
6069#endif // defined(ACTIVATION_TYPE)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006070
6071 // Store the output block
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006072 vstore2(acc0, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006073#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006074 vstore2(acc1, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006075#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6076#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006077 vstore2(acc2, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006078#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6079#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006080 vstore2(acc3, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00006081#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006082}
6083
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006084#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006085/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped
6086 *
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006087 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
6088 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
6089 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
6090 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
6091 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006092 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
6093 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006094 *
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006095 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
6096 * The activation function is performed after the bias addition
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006097 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
6098 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
6099 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
6100 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
6101 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
6102 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
6103 *
6104 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
6105 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
6106 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6107 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
6108 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6109 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
6110 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
6111 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
6112 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
6113 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
6114 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
6115 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006116 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
6117 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
6118 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
6119 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
6120 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
6121 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006122 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
6123 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
6124 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
6125 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
6126 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
6127 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
6128 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
6129 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006130 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006131 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
6132 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
6133 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
6134 */
6135__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
6136 IMAGE_DECLARATION(src1),
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006137#if defined(BETA)
6138 IMAGE_DECLARATION(src2),
6139#endif // defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006140 IMAGE_DECLARATION(dst),
6141 uint src0_stride_z,
6142 uint src1_stride_z,
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006143#if defined(BETA)
6144 uint src2_stride_z,
6145#endif //defined(BETA)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006146 uint dst_stride_z
6147#if defined(REINTERPRET_INPUT_AS_3D)
6148 ,
6149 uint src_cross_plane_pad
6150#endif // REINTERPRET_INPUT_AS_3D
6151#if defined(REINTERPRET_OUTPUT_AS_3D)
6152 ,
6153 uint dst_cross_plane_pad
6154#endif // REINTERPRET_OUTPUT_AS_3D
6155 )
6156{
6157 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
6158
6159 // Compute starting address for matrix A and Matrix B
6160 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
6161
6162 // Update address for the matrix A
6163 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
6164
6165 // Update address for the matrix B
6166 src_addr.s1 += idx * sizeof(half);
6167
6168#if defined(REINTERPRET_INPUT_AS_3D)
6169 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
6170 // in order to take into account the presence of possible cross plane paddings
6171 //
6172 // | |
6173 // | plane0 |
6174 // | |
6175 // |__________________|
6176 // |******************|
6177 // | cross_plane_pad |
6178 // |******************|
6179 // | |
6180 // | plane1 |
6181 // | |
6182 // |__________________|
6183
6184 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
6185 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6186 zin = min(DEPTH_GEMM3D - 1, zin);
6187
6188 // Add offset due to the cross plane paddings
6189 zin *= (src_cross_plane_pad * src0_stride_y);
6190
6191 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6192 // multiply src0_stride_z by DEPTH_GEMM3D
6193 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
6194
6195#else // defined(REINTERPRET_INPUT_AS_3D)
6196
6197 // Add offset for batched GEMM
6198 src_addr.s0 += get_global_id(2) * src0_stride_z;
6199
6200#endif // defined(REINTERPRET_INPUT_AS_3D)
6201
6202#if defined(MATRIX_B_DEPTH)
6203 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
6204 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
6205#else // defined(MATRIX_B_DEPTH)
6206 src_addr.s1 += get_global_id(2) * src1_stride_z;
6207#endif // defined(MATRIX_B_DEPTH)
6208
6209 float8 acc0 = 0.0h;
6210#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6211 float8 acc1 = 0.0h;
6212#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6213#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6214 float8 acc2 = 0.0h;
6215#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6216#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6217 float8 acc3 = 0.0h;
6218#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6219
6220 int i = 0;
6221 for(; i <= ((int)COLS_A - 4); i += 4)
6222 {
6223#if defined(REINTERPRET_INPUT_AS_3D)
6224 // Load values from matrix A
Usama Arif0681e3b2019-04-25 14:28:07 +01006225 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
6226#else // defined(REINTERPRET_INPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006227 // Load values from matrix A
6228 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6229#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6230 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6231#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6232#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6233 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6234#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6235#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6236 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6237#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6238#endif // defined(REINTERPRET_INPUT_AS_3D)
6239
6240 // Load values from matrix B
6241 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
6242 src_addr.s1 += src1_stride_y;
6243
6244 // Accumulate
6245 acc0 = fma(b0, (float8)a0.s0, acc0);
6246#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6247 acc1 = fma(b0, (float8)a1.s0, acc1);
6248#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6249#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6250 acc2 = fma(b0, (float8)a2.s0, acc2);
6251#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6252#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6253 acc3 = fma(b0, (float8)a3.s0, acc3);
6254#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6255
6256 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
6257 src_addr.s1 += src1_stride_y;
6258 acc0 = fma(b0, (float8)a0.s1, acc0);
6259#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6260 acc1 = fma(b0, (float8)a1.s1, acc1);
6261#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6262#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6263 acc2 = fma(b0, (float8)a2.s1, acc2);
6264#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6265#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6266 acc3 = fma(b0, (float8)a3.s1, acc3);
6267#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6268
6269 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
6270 src_addr.s1 += src1_stride_y;
6271 acc0 = fma(b0, (float8)a0.s2, acc0);
6272#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6273 acc1 = fma(b0, (float8)a1.s2, acc1);
6274#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6275#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6276 acc2 = fma(b0, (float8)a2.s2, acc2);
6277#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6278#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6279 acc3 = fma(b0, (float8)a3.s2, acc3);
6280#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6281
6282 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
6283 src_addr.s1 += src1_stride_y;
6284 acc0 = fma(b0, (float8)a0.s3, acc0);
6285#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6286 acc1 = fma(b0, (float8)a1.s3, acc1);
6287#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6288#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6289 acc2 = fma(b0, (float8)a2.s3, acc2);
6290#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6291#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6292 acc3 = fma(b0, (float8)a3.s3, acc3);
6293#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6294
6295 src_addr.s0 += 4 * sizeof(half);
6296 }
6297
6298 for(; i < (int)COLS_A; ++i)
6299 {
6300#if defined(REINTERPRET_INPUT_AS_3D)
6301 // Load values from matrix A
6302 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
6303#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6304 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
6305#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6307 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
6308#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6309#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6310 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
6311#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6312#else // defined(REINTERPRET_INPUT_AS_3D)
6313 // Load values from matrix A
6314 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
6315#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6316 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
6317#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6318#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6319 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
6320#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6321#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6322 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
6323#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6324#endif // defined(REINTERPRET_INPUT_AS_3D)
6325
6326 // Load values from matrix B
6327 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
6328
6329 src_addr += (int2)(sizeof(half), src1_stride_y);
6330
6331 // Accumulate
6332 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
6333#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6334 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
6335#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6336#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6337 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
6338#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6339#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6340 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
6341#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6342 }
6343
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006344 int z = get_global_id(2);
6345
6346 // Compute destination address
6347 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
6348
6349 // Compute dst address
6350 __global uchar *dst_addr = offset(&dst, 0, 0);
6351
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006352 uint4 zout = 0;
6353
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006354#if defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006355
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006356 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
6357 // in order to take into account the presence of possible cross plane paddings
6358 //
6359 // | |
6360 // | plane0 |
6361 // | |
6362 // |__________________|
6363 // |******************|
6364 // | cross_plane_pad |
6365 // |******************|
6366 // | |
6367 // | plane1 |
6368 // | |
6369 // |__________________|
6370
6371 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006372 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
6373 zout = min(DEPTH_GEMM3D - 1, zout);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006374
6375 // Add offset due to the cross plane paddings
6376 zout *= (dst_cross_plane_pad * dst_stride_y);
6377
6378 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
6379 // multiply dst_stride_z by DEPTH_GEMM3D
6380 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006381#else // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006382 // Add offset for batched GEMM
6383 dst_addr += z * dst_stride_z;
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006384#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006385
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006386 // Multiply by the weight of matrix-matrix product and store the result
6387#if defined(ALPHA)
6388 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, acc, ALPHA);
6389#endif // defined(ALPHA)
6390
6391#if defined(BETA)
6392 REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);
6393
6394#if defined(BROADCAST_BIAS)
6395 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
6396
6397 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6398
6399 float8 bias_f0 = convert_float8(bias0);
6400
6401#ifndef UNIT_BETA
6402 SCALE_BLOCK(1, float, bias_f, BETA);
6403#endif // UNIT_BIAS
6404
6405 // acc = acc + bias[broadcasted]
6406 ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f0);
6407
6408#else // defined(BROADCAST_BIAS)
6409 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
6410 (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;
6411
6412 LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
6413
6414 float8 bias_f0 = convert_float8(bias0);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006415#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006416 float8 bias_f1 = convert_float8(bias1);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006417#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6418#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006419 float8 bias_f2 = convert_float8(bias2);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006420#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6421#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006422 float8 bias_f3 = convert_float8(bias3);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006423#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiced1f54762019-07-19 09:54:47 +01006424
6425#ifndef UNIT_BETA
6426 SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, float, bias_f, BETA);
6427#endif // UNIT_BIAS
6428
6429 // acc = acc + bias
6430 ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias_f);
6431
6432#endif // defined(BROADCAST_BIAS)
6433#endif // defined(BETA)
6434
6435 half8 acc_h0 = convert_half8(acc0);
6436#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6437 half8 acc_h1 = convert_half8(acc1);
6438#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
6439#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6440 half8 acc_h2 = convert_half8(acc2);
6441#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
6442#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6443 half8 acc_h3 = convert_half8(acc3);
6444#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
6445
6446#if defined(ACTIVATION_TYPE)
6447 ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc_h, A_VAL, B_VAL);
6448#endif // defined(ACTIVATION_TYPE)
6449
6450 // Store the output block
6451 STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc_h, dst_addr, dst_stride_y, zout.s);
Vidhya Sudhan Loganathana25d16c2018-11-16 11:33:12 +00006452}
6453
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 *       The accumulators are kept in half precision (half8).
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
 *       NOTE(review): the kernel body loads 8 half values of matrix B per row, stores 8 half values per output row and computes the bias
 *       offset with a hard-coded 8, so -DNUM_ELEMS_PROCESSED_PER_THREAD_X=8 looks like the value the code actually expects - confirm against the host-side configuration.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type is passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), the A and B variables required by some activation functions must be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row handled by this work-item
    half8 acc0 = 0.0h;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = 0.0h;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of matrix A (and therefore 4 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the 4 columns just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Leftover loop: process the remaining (COLS_A % 4) columns of matrix A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance matrix A by one column and matrix B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) *
                                (uint)NUM_ELEMS_PROCESSED_PER_THREAD_Y * src2_stride_y) + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, ACTIVATION_TYPE, half, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    STORE_BLOCK(NUM_ELEMS_PROCESSED_PER_THREAD_Y, 8, half, acc, dst_addr, dst_stride_y, zout.s);
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01006799#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01006800
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01006801#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006802
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006803#if defined(BETA)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * Each work-item reads 4 float values from the destination (the A x B result) and 4 float values from the source (matrix C),
 * computes dst + BETA * src and writes the 4 results back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B
    float4 alpha_ab = vload4(0, (__global float *)dst.ptr);

    // Load values from Matrix C
    float4 c = vload4(0, (__global float *)src.ptr);

    // Computes alpha * axb + beta * c
    float4 out = alpha_ab + (float4)BETA * c;

    // Store final result in axb matrix
    vstore4(out, 0, (__global float *)dst.ptr);
}
6844
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006845#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel performs the in-place matrix addition between 2 matrices taking into account that the second matrix might be weighted by a scalar value beta:
 *
 * Each work-item reads 8 half values from the destination (the A x B result) and 8 half values from the source (matrix C),
 * computes dst + BETA * src and writes the 8 results back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Compute source and destination addresses
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D dst = CONVERT_TO_TENSOR3D_STRUCT(dst);

    // Load values from A x B
    half8 alpha_ab = vload8(0, (__global half *)dst.ptr);

    // Load values from Matrix C
    half8 c = vload8(0, (__global half *)src.ptr);

    // Computes alpha * axb + beta * c
    half8 out = alpha_ab + (half8)BETA * c;

    // Store final result in axb matrix
    vstore8(out, 0, (__global half *)dst.ptr);
}
Vidhya Sudhan Loganathan76c85642018-05-25 13:53:02 +01006886#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006887#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01006888
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006889#if defined(WIDTH_VECTOR_A)
/** This OpenCL kernel computes the vector by matrix multiplication between each row of A (src0) and matrix B (src1) used for locally connected layer
 *
 * Each work-item produces 4 consecutive float outputs for row get_global_id(1): that row of A is multiplied by the
 * slice of B selected with the same index along the Z dimension of src1.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    int idx = get_global_id(0) * 4;
    int idy = get_global_id(1);

    // Compute the address for the vector A and matrix B
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes + src0_stride_y * idy, src1_offset_first_element_in_bytes + src1_stride_z * idy));
    src_addr.s1 += idx * sizeof(float);

    // Byte offset one past the last element of the current row of A
    int end_row_vec_a = src_addr.s0 + (WIDTH_VECTOR_A * sizeof(float));

    float4 acc = 0.0f;

    // Main loop: consume 2 elements of A and the matching 2 rows of B per iteration
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(float)); src_addr += (int2)(2 * sizeof(float), 2 * src1_stride_y))
    {
        float2 a0 = vload2(0, (__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
        float4 b1 = vload4(0, (__global float *)(src1_ptr + src_addr.s1 + src1_stride_y));

        acc += b0 * (float4)a0.s0;
        acc += b1 * (float4)a0.s1;
    }

    // Leftover loop: handle the last element of A when WIDTH_VECTOR_A is odd
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(float), src1_stride_y))
    {
        float  a0 = *((__global float *)(src0_ptr + src_addr.s0));
        float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));

        acc += b0 * (float4)a0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(acc, 0, (__global float *)(offset(&dst, 0, 0)));
}
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00006955#endif // defined(WIDTH_VECTOR_A)
6956
/** This kernel adds the biases vector to each row of the accumulate tensor, in place.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    Image  accum_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector biases_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    // Typed views of the addresses resolved for this work-item
    __global DATA_TYPE *accum_addr  = (__global DATA_TYPE *)accum_img.ptr;
    __global DATA_TYPE *biases_addr = (__global DATA_TYPE *)biases_vec.ptr;

    // Load VECTOR_SIZE accumulator elements and the VECTOR_SIZE matching bias elements
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    accum_value = VLOAD(VECTOR_SIZE)(0, accum_addr);
    VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE)
    biases_value = VLOAD(VECTOR_SIZE)(0, biases_addr);

    // Write the biased values back over the elements that were just read
    VSTORE(VECTOR_SIZE)
    (biases_value + accum_value, 0, accum_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)