blob: cf1e02192975cc2268229cf3ae7e8905f6459aff [file] [log] [blame]
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001/*
Gian Marco36a0a462018-01-12 10:21:40 +00002 * Copyright (c) 2017-2018 ARM Limited.
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
#if defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the lhs input matrix. The kernel splits the input matrix in blocks of size M0xK0 and stores each one (not transposed) in
 *  the output matrix unrolling the values.
 *
 * Each work-item loads M0 rows of K0 elements from the source and writes them as M0 vectors of K0 elements to the destination.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The block's dimensions (M0 and K0) must be passed at compile time using -DM0 and -DK0 (i.e. -DM0=2, -DK0=2).
 * @note The number of M0xK0 vertical blocks to store on the same output row must be passed at compile time using -DV0 (i.e. -DV0=2)
 * @note Only the following values for M0, K0 and V0 are supported:
 *                                      M0: 2,3,4,5,6,7,8
 *                                      K0: 2,4,8,16
 *                                      V0: greater than 0
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 * @note If the M0xK0 blocks have to be interleaved, the option -DINTERLEAVE must be passed at compile time.
 *
 * @param[in]  src_ptr                           Pointer to the source LHS tensor. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source LHS tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source LHS tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source LHS tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source LHS tensor
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_reshape_lhs_matrix_nt(TENSOR3D_DECLARATION(src),
                                         TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                         ,
                                         uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                        )
{
// Number of elements in one M0xK0 block
#define BLOCK_SIZE ((M0) * (K0))

// Offset (in elements) between the blocks of two consecutive work-items along Y that share the same output row
#if defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (K0)
#else // defined(INTERLEAVE)
#define OUTPUT_OFFSET_X (BLOCK_SIZE)
#endif // defined(INTERLEAVE)

// Distance (in elements) between two consecutive rows of the same block in the output.
// The expansion is fully parenthesised so that it is safe in any surrounding expression.
#if defined(INTERLEAVE)
#define OUTPUT_STEP_X ((K0) * (V0))
#else // Do not interleave
#define OUTPUT_STEP_X (K0)
#endif // defined(INTERLEAVE)

    // Work-item coordinates
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // ------------------ Compute input/output addresses ---------------------------

    // Compute the input address: x selects the K0-wide column block, y selects the M0-tall row block
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * (uint)K0 * sizeof(DATA_TYPE) + y * (uint)M0 * src_stride_y;

    // Compute the output address: V0 vertical blocks share one output row
    __global uchar *output_ptr = dst_ptr + dst_offset_first_element_in_bytes + (x * (uint)BLOCK_SIZE * (uint)V0 * sizeof(DATA_TYPE)) + ((y / (uint)V0) * (uint)dst_stride_y) + ((y % (uint)V0) *
                                 (uint)OUTPUT_OFFSET_X * sizeof(DATA_TYPE));

    // Per-row byte offsets due to the cross plane paddings. They stay 0 unless REINTERPRET_INPUT_AS_3D is defined
    uint zin0 = 0;
    uint zin1 = 0;
    uint zin2 = 0;
    uint zin3 = 0;
    uint zin4 = 0;
    uint zin5 = 0;
    uint zin6 = 0;
    uint zin7 = 0;

#if defined(REINTERPRET_INPUT_AS_3D)
    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D

    // Note for the REINTERPRET_INPUT_AS_3D case
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    input_ptr += z * (uint)src_stride_z * DEPTH_GEMM3D;

    // The plane (zin) is calculated dividing M (y * M0) by HEIGHT_GEMM3D.
    // The plane index is clamped to DEPTH_GEMM3D - 1 and then scaled by the size (in bytes) of the cross plane padding.
    zin0 = (0 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin0 = min((uint)(DEPTH_GEMM3D - 1), zin0);
    zin0 *= (cross_plane_pad * src_stride_y);
#if M0 > 1
    zin1 = (1 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin1 = min((uint)(DEPTH_GEMM3D - 1), zin1);
    zin1 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 1
#if M0 > 2
    zin2 = (2 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin2 = min((uint)(DEPTH_GEMM3D - 1), zin2);
    zin2 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 2
#if M0 > 3
    zin3 = (3 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin3 = min((uint)(DEPTH_GEMM3D - 1), zin3);
    zin3 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 3
#if M0 > 4
    zin4 = (4 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin4 = min((uint)(DEPTH_GEMM3D - 1), zin4);
    zin4 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 4
#if M0 > 5
    zin5 = (5 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin5 = min((uint)(DEPTH_GEMM3D - 1), zin5);
    zin5 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 5
#if M0 > 6
    zin6 = (6 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin6 = min((uint)(DEPTH_GEMM3D - 1), zin6);
    zin6 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 6
    // The eighth row only exists when M0 > 7: this guard must match the one used for the a7 load/store below
#if M0 > 7
    zin7 = (7 + (uint)(y * M0)) / (uint)HEIGHT_GEMM3D;
    zin7 = min((uint)(DEPTH_GEMM3D - 1), zin7);
    zin7 *= (cross_plane_pad * src_stride_y);
#endif // M0 > 7

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    input_ptr += z * (uint)src_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    output_ptr += z * (uint)dst_stride_z;

    // ---------------------------Load input values --------------------------------

    // Load one K0-wide vector per row of the M0xK0 block from the LHS matrix
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a0 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin0));
#if M0 > 1
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a1 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin1));
#endif // M0 > 1
#if M0 > 2
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a2 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin2));
#endif // M0 > 2
#if M0 > 3
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a3 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin3));
#endif // M0 > 3
#if M0 > 4
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a4 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 4 * src_stride_y + zin4));
#endif // M0 > 4
#if M0 > 5
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a5 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 5 * src_stride_y + zin5));
#endif // M0 > 5
#if M0 > 6
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a6 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 6 * src_stride_y + zin6));
#endif // M0 > 6
#if M0 > 7
    VEC_DATA_TYPE(DATA_TYPE, K0)
    a7 = VLOAD(K0)(0, (__global DATA_TYPE *)(input_ptr + 7 * src_stride_y + zin7));
#endif // M0 > 7

    // ---------------------------Store output values ------------------------------

    // Row i of the block is written OUTPUT_STEP_X elements after row i-1
    VSTORE(K0)
    (a0, 0, (__global DATA_TYPE *)(output_ptr + 0 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#if M0 > 1
    VSTORE(K0)
    (a1, 0, (__global DATA_TYPE *)(output_ptr + 1 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 1
#if M0 > 2
    VSTORE(K0)
    (a2, 0, (__global DATA_TYPE *)(output_ptr + 2 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 2
#if M0 > 3
    VSTORE(K0)
    (a3, 0, (__global DATA_TYPE *)(output_ptr + 3 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 3
#if M0 > 4
    VSTORE(K0)
    (a4, 0, (__global DATA_TYPE *)(output_ptr + 4 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 4
#if M0 > 5
    VSTORE(K0)
    (a5, 0, (__global DATA_TYPE *)(output_ptr + 5 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 5
#if M0 > 6
    VSTORE(K0)
    (a6, 0, (__global DATA_TYPE *)(output_ptr + 6 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 6
#if M0 > 7
    VSTORE(K0)
    (a7, 0, (__global DATA_TYPE *)(output_ptr + 7 * OUTPUT_STEP_X * sizeof(DATA_TYPE)));
#endif // M0 > 7

#undef BLOCK_SIZE
#undef OUTPUT_OFFSET_X
#undef OUTPUT_STEP_X
}
#endif // defined(M0) && defined(K0) && defined(V0) && defined(DATA_TYPE)
254
#if defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)

// Select an unsigned integer type whose size matches ELEMENT_SIZE (in bytes)
#if ELEMENT_SIZE == 1
#define DATA_TYPE uchar
#elif ELEMENT_SIZE == 2
#define DATA_TYPE ushort
#elif ELEMENT_SIZE == 4
#define DATA_TYPE uint
#else // ELEMENT_SIZE is not 1, 2 or 4
#error "Element size not supported"
#endif // ELEMENT_SIZE

/** This OpenCL kernel computes the "vector" 1xW transposition of input matrix
 *
 * Each work-item copies one vector of TRANSPOSE_W elements from the source to its reshaped position in the destination.
 *
 * @note The transposition width must be passed at compile time using -DTRANSPOSE_W (e.g. -DTRANSPOSE_W=4)
 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
 * @note The size in bytes of a single element must be passed at compile time using -DELEMENT_SIZE. Only 1, 2 and 4 are accepted.
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_transpose1xW(TENSOR3D_DECLARATION(src),
                                TENSOR3D_DECLARATION(dst))
{
    // Work-item coordinates
    const uint col_id   = get_global_id(0);
    const uint row_id   = get_global_id(1);
    const uint batch_id = get_global_id(2);

    // Source (Matrix B): the struct already points at the elements of this work-item
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    // Read one vector of TRANSPOSE_W elements
    VEC_DATA_TYPE(DATA_TYPE, TRANSPOSE_W)
    b0 = VLOAD(TRANSPOSE_W)(0, (__global DATA_TYPE *)src.ptr);

    // Destination (Matrix B transposed): X and Y are swapped.
    // The source row selects the position along the output row, the source column selects the output row:
    // MULT_TRANSPOSE1XW_WIDTH consecutive source columns are packed side by side on the same output row.
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes;
    dst_addr_in_bytes += row_id * TRANSPOSE_W * sizeof(DATA_TYPE) * MULT_TRANSPOSE1XW_WIDTH;
    dst_addr_in_bytes += (col_id / MULT_TRANSPOSE1XW_WIDTH) * dst_stride_y;
    dst_addr_in_bytes += (col_id % MULT_TRANSPOSE1XW_WIDTH) * TRANSPOSE_W * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += batch_id * dst_stride_z;

    // Write the vector to its reshaped location
    VSTORE(TRANSPOSE_W)
    (b0, 0, (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes));
}
#endif // defined(TRANSPOSE_W) && defined(MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100313
#if defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)

/** This OpenCL kernel reshapes the input matrix transposing each 4x4 block. If -DUNROLL_BLOCK is passed at compile time, the 4x4 block
 *  will be simply unrolled.
 *
 * Each work-item loads 4 rows of 4 elements from the source and writes 4 vectors of 4 elements to the destination.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=float)
 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
 * @note In case the input has to be reinterpreted as a 3D tensor (i.e. input of convolution layer 1x1), the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# HEIGHT_GEMM3D: The height of the input in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the input in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src_ptr                           Pointer to the source matrix. Supported data types: U8/S8/QASYMM8/U16/S16/F16/U32/S32/F32
 * @param[in]  src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                           Pointer to the destination matrix Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 * @param[in]  cross_plane_pad                   (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_INPUT_AS_3D)
 */
__kernel void gemm_interleave4x4(TENSOR3D_DECLARATION(src),
                                 TENSOR3D_DECLARATION(dst)
#if defined(REINTERPRET_INPUT_AS_3D)
                                 ,
                                 uint cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
                                )
{
    // Work-item coordinates
    uint x = get_global_id(0);
    uint y = get_global_id(1);
    uint z = get_global_id(2);

    // Compute the destination address of the reshaped 4x4 block:
    // MULT_INTERLEAVE4X4_HEIGHT vertically adjacent blocks are interleaved on the same output row
    uint dst_addr_in_bytes = dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) * 16 * MULT_INTERLEAVE4X4_HEIGHT + (y / MULT_INTERLEAVE4X4_HEIGHT) * dst_stride_y +
                             (y % MULT_INTERLEAVE4X4_HEIGHT) * 4 * sizeof(DATA_TYPE);

    // Add offset for batched GEMM
    dst_addr_in_bytes += z * dst_stride_z;

#if defined(REINTERPRET_INPUT_AS_3D)
    // The source address is computed by hand because every row may need its own cross plane padding offset
    __global uchar *input_ptr = src_ptr + src_offset_first_element_in_bytes + x * 4 * sizeof(DATA_TYPE) + y * 4 * src_stride_y;

    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (y * 4) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(y * 4)) / (uint4)HEIGHT_GEMM3D;
    // Clamp the plane index to the last plane. Both operands are uint4 so that no signed/unsigned mix is involved
    zin = min((uint4)(DEPTH_GEMM3D - 1), zin);

    // Add offset due to the cross plane paddings
    zin *= (cross_plane_pad * src_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src_stride_z by DEPTH_GEMM3D
    input_ptr += z * src_stride_z * DEPTH_GEMM3D;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y + zin.s0));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y + zin.s1));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y + zin.s2));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Compute address for source tensor. The struct is only needed on this path
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    __global uchar *input_ptr = src.ptr;

    // Load values from Matrix A
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a0 = vload4(0, (__global DATA_TYPE *)(input_ptr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a1 = vload4(0, (__global DATA_TYPE *)(input_ptr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a2 = vload4(0, (__global DATA_TYPE *)(input_ptr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    a3 = vload4(0, (__global DATA_TYPE *)(input_ptr + 3 * src_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)

    // Typed pointer to the first output element of this work-item
    __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_addr_in_bytes);

#if defined(UNROLL_BLOCK)
    // Store the four rows as they are (no transposition)
    vstore4(a0, 0, dst_addr + 0 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a1, 0, dst_addr + 4 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a2, 0, dst_addr + 8 * MULT_INTERLEAVE4X4_HEIGHT);
    vstore4(a3, 0, dst_addr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#else  // defined(UNROLL_BLOCK)
    // Store the four columns of the 4x4 block, i.e. the transposed block
    VEC_DATA_TYPE(DATA_TYPE, 4)
    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s0, a1.s0, a2.s0, a3.s0);
    vstore4(val0, 0, dst_addr + 0 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s1, a1.s1, a2.s1, a3.s1);
    vstore4(val0, 0, dst_addr + 4 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s2, a1.s2, a2.s2, a3.s2);
    vstore4(val0, 0, dst_addr + 8 * MULT_INTERLEAVE4X4_HEIGHT);

    val0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(a0.s3, a1.s3, a2.s3, a3.s3);
    vstore4(val0, 0, dst_addr + 12 * MULT_INTERLEAVE4X4_HEIGHT);
#endif // defined(UNROLL_BLOCK)
}
#endif // defined(MULT_INTERLEAVE4X4_HEIGHT) && defined(DATA_TYPE)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100441
Gian Marco36a0a462018-01-12 10:21:40 +0000442#if defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100443/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4 and @ref gemm_transpose1xW before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100445 *
Gian Marco19835e52018-01-30 13:35:54 +0000446 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
447 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
448 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000449 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
450 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100451 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000452 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
453 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
454 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
455 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
456 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
457 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100458 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
459 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
460 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
461 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
462 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
463 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100464 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100465 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
466 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
467 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
468 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
469 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100470 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100471 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000472 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100473 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000474 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100475 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000476 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
477 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
478 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100479 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100480 */
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +0100481__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
482 IMAGE_DECLARATION(src1),
483 IMAGE_DECLARATION(dst),
484 uint src0_stride_z,
485 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000486 uint dst_stride_z
487#if defined(REINTERPRET_OUTPUT_AS_3D)
488 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100489 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000490#endif // REINTERPRET_OUTPUT_AS_3D
491 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100492{
Gian Marco36a0a462018-01-12 10:21:40 +0000493 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
494 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000495 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100496
Gian Marco36a0a462018-01-12 10:21:40 +0000497 // Offset
498 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
499 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100500
Gian Marco36a0a462018-01-12 10:21:40 +0000501 // src_addr_a = address of matrix A
502 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000503 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
504 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
505
506#if defined(MATRIX_B_DEPTH)
507 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
508 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
509#else // defined(MATRIX_B_DEPTH)
510 src1_addr_in_bytes += z * src1_stride_z;
511#endif // defined(MATRIX_B_DEPTH)
512
513 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
514 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100515
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000516 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +0000517 __global float *src_end_addr_b = src_addr_b + COLS_B;
518
519 src_addr_a += offset_row_a;
520 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100521
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000522 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100523 float4 c00 = 0.0f;
524 float4 c10 = 0.0f;
525 float4 c20 = 0.0f;
526 float4 c30 = 0.0f;
527
Gian Marco36a0a462018-01-12 10:21:40 +0000528 for(; src_addr_b <= (src_end_addr_b - (int)(8 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100529 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000530 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000531 float4 a0 = vload4(0, src_addr_a);
532 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100533
534 c00 += (float4)a0.s0 * b0;
535 c10 += (float4)a0.s1 * b0;
536 c20 += (float4)a0.s2 * b0;
537 c30 += (float4)a0.s3 * b0;
538
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000539 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000540 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
541 b0 = vload4(0, src_addr_b + 4 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100542
543 c00 += (float4)a0.s0 * b0;
544 c10 += (float4)a0.s1 * b0;
545 c20 += (float4)a0.s2 * b0;
546 c30 += (float4)a0.s3 * b0;
547 }
548
Gian Marco36a0a462018-01-12 10:21:40 +0000549 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100550 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000551 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +0000552 float4 a0 = vload4(0, src_addr_a);
553 float4 b0 = vload4(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100554
555 c00 += (float4)a0.s0 * b0;
556 c10 += (float4)a0.s1 * b0;
557 c20 += (float4)a0.s2 * b0;
558 c30 += (float4)a0.s3 * b0;
559 }
560
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000561 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100562 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
563
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000564#if defined(ALPHA)
565 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100566 c00 = c00 * (float4)ALPHA;
567 c10 = c10 * (float4)ALPHA;
568 c20 = c20 * (float4)ALPHA;
569 c30 = c30 * (float4)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000570#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100571
Gian Marcoae2af742018-02-15 12:35:44 +0000572 // Compute dst address
573 __global uchar *dst_addr = offset(&dst, 0, 0);
574
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000575#if defined(REINTERPRET_OUTPUT_AS_3D)
576 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100577 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000578 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100579 // | |
580 // | plane0 |
581 // | |
582 // |__________________|
583 // |******************|
584 // | cross_plane_pad |
585 // |******************|
586 // | |
587 // | plane1 |
588 // | |
589 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000590
591 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
592 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
593 zout = min(DEPTH_GEMM3D - 1, zout);
594
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100595 // Add offset due to the cross plane paddings
596 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000597
598 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
599 // multiply dst_stride_z by DEPTH_GEMM3D
600 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
601
602 // Store 4x4 block
603 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
604 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
605 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
606 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
607
608#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000609 // Add offset for batched GEMM
610 dst_addr += z * dst_stride_z;
611
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000612 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000613 vstore4(c00, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
614 vstore4(c10, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
615 vstore4(c20, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
616 vstore4(c30, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000617#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100618}
619
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000620/** This OpenCL kernel is optimized for Bifrost. It computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100621 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_32bit and @ref gemm_transpose1x4 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100622 *
Gian Marco19835e52018-01-30 13:35:54 +0000623 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
624 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
625 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000626 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
627 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
628 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100629 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000630 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
631 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
632 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
633 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
634 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
635 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100636 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
637 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
638 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
639 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
640 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
641 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100642 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100643 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
644 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
645 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
646 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
647 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100648 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100649 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000650 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100651 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000652 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100653 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000654 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
655 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
656 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100657 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100658 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100659__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
660 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000661 IMAGE_DECLARATION(dst),
662 uint src0_stride_z,
663 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000664 uint dst_stride_z
665#if defined(REINTERPRET_OUTPUT_AS_3D)
666 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100667 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000668#endif // REINTERPRET_OUTPUT_AS_3D
669 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100670{
Gian Marco36a0a462018-01-12 10:21:40 +0000671 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
672 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000673 int z = get_global_id(2);
Gian Marco36a0a462018-01-12 10:21:40 +0000674
675 // Offset
676 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
677 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 4;
678
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100679 // src_addr_a = address of matrix A
680 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000681 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
682 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
683
684#if defined(MATRIX_B_DEPTH)
685 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
686 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
687#else // defined(MATRIX_B_DEPTH)
688 src1_addr_in_bytes += z * src1_stride_z;
689#endif // defined(MATRIX_B_DEPTH)
690
691 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
692 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100693
Gian Marco36a0a462018-01-12 10:21:40 +0000694 src_addr_a += offset_row_a;
695 src_addr_b += offset_row_b;
696
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100697 // Reset accumulators
698 float c00 = 0.0f;
699 float c01 = 0.0f;
700 float c02 = 0.0f;
701 float c03 = 0.0f;
702 float c10 = 0.0f;
703 float c11 = 0.0f;
704 float c12 = 0.0f;
705 float c13 = 0.0f;
706 float c20 = 0.0f;
707 float c21 = 0.0f;
708 float c22 = 0.0f;
709 float c23 = 0.0f;
710 float c30 = 0.0f;
711 float c31 = 0.0f;
712 float c32 = 0.0f;
713 float c33 = 0.0f;
714
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100715#define COLS_MTX_B (COLS_B / (4 * MULT_TRANSPOSE1XW_WIDTH))
716
717 int i = 0;
718 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100719 {
720 // Load values from matrix A (interleaved) and matrix B (transposed)
721 float4 a0 = vload4(0, src_addr_a);
722 float4 b0 = vload4(0, src_addr_b);
723
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100724 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
725 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100726
727 c00 = fma(a0.s0, b0.s0, c00);
728 c01 = fma(a0.s0, b0.s1, c01);
729 c02 = fma(a0.s0, b0.s2, c02);
730 c03 = fma(a0.s0, b0.s3, c03);
731
732 c10 = fma(a0.s1, b0.s0, c10);
733 c11 = fma(a0.s1, b0.s1, c11);
734 c12 = fma(a0.s1, b0.s2, c12);
735 c13 = fma(a0.s1, b0.s3, c13);
736
737 c20 = fma(a0.s2, b0.s0, c20);
738 c21 = fma(a0.s2, b0.s1, c21);
739 c22 = fma(a0.s2, b0.s2, c22);
740 c23 = fma(a0.s2, b0.s3, c23);
741
742 c30 = fma(a0.s3, b0.s0, c30);
743 c31 = fma(a0.s3, b0.s1, c31);
744 c32 = fma(a0.s3, b0.s2, c32);
745 c33 = fma(a0.s3, b0.s3, c33);
746
747 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100748 a0 = vload4(0, src_addr_a);
749 b0 = vload4(0, src_addr_b);
750
751 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
752 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100753
754 c00 = fma(a0.s0, b0.s0, c00);
755 c01 = fma(a0.s0, b0.s1, c01);
756 c02 = fma(a0.s0, b0.s2, c02);
757 c03 = fma(a0.s0, b0.s3, c03);
758
759 c10 = fma(a0.s1, b0.s0, c10);
760 c11 = fma(a0.s1, b0.s1, c11);
761 c12 = fma(a0.s1, b0.s2, c12);
762 c13 = fma(a0.s1, b0.s3, c13);
763
764 c20 = fma(a0.s2, b0.s0, c20);
765 c21 = fma(a0.s2, b0.s1, c21);
766 c22 = fma(a0.s2, b0.s2, c22);
767 c23 = fma(a0.s2, b0.s3, c23);
768
769 c30 = fma(a0.s3, b0.s0, c30);
770 c31 = fma(a0.s3, b0.s1, c31);
771 c32 = fma(a0.s3, b0.s2, c32);
772 c33 = fma(a0.s3, b0.s3, c33);
773
774 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100775 a0 = vload4(0, src_addr_a);
776 b0 = vload4(0, src_addr_b);
777
778 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
779 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
780
781 c00 = fma(a0.s0, b0.s0, c00);
782 c01 = fma(a0.s0, b0.s1, c01);
783 c02 = fma(a0.s0, b0.s2, c02);
784 c03 = fma(a0.s0, b0.s3, c03);
785
786 c10 = fma(a0.s1, b0.s0, c10);
787 c11 = fma(a0.s1, b0.s1, c11);
788 c12 = fma(a0.s1, b0.s2, c12);
789 c13 = fma(a0.s1, b0.s3, c13);
790
791 c20 = fma(a0.s2, b0.s0, c20);
792 c21 = fma(a0.s2, b0.s1, c21);
793 c22 = fma(a0.s2, b0.s2, c22);
794 c23 = fma(a0.s2, b0.s3, c23);
795
796 c30 = fma(a0.s3, b0.s0, c30);
797 c31 = fma(a0.s3, b0.s1, c31);
798 c32 = fma(a0.s3, b0.s2, c32);
799 c33 = fma(a0.s3, b0.s3, c33);
800
801 // Load values from matrix A (interleaved) and matrix B (transposed)
802 a0 = vload4(0, src_addr_a);
803 b0 = vload4(0, src_addr_b);
804
805 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
806 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100807
808 c00 = fma(a0.s0, b0.s0, c00);
809 c01 = fma(a0.s0, b0.s1, c01);
810 c02 = fma(a0.s0, b0.s2, c02);
811 c03 = fma(a0.s0, b0.s3, c03);
812
813 c10 = fma(a0.s1, b0.s0, c10);
814 c11 = fma(a0.s1, b0.s1, c11);
815 c12 = fma(a0.s1, b0.s2, c12);
816 c13 = fma(a0.s1, b0.s3, c13);
817
818 c20 = fma(a0.s2, b0.s0, c20);
819 c21 = fma(a0.s2, b0.s1, c21);
820 c22 = fma(a0.s2, b0.s2, c22);
821 c23 = fma(a0.s2, b0.s3, c23);
822
823 c30 = fma(a0.s3, b0.s0, c30);
824 c31 = fma(a0.s3, b0.s1, c31);
825 c32 = fma(a0.s3, b0.s2, c32);
826 c33 = fma(a0.s3, b0.s3, c33);
827 }
828
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100829 for(; i < (int)(COLS_MTX_B); ++i)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100830 {
831 // Load values from matrix A (interleaved) and matrix B (transposed)
832 float4 a0 = vload4(0, src_addr_a);
833 float4 b0 = vload4(0, src_addr_b);
834
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +0100835 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
836 src_addr_b += 4 * MULT_TRANSPOSE1XW_WIDTH;
837
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100838 c00 = fma(a0.s0, b0.s0, c00);
839 c01 = fma(a0.s0, b0.s1, c01);
840 c02 = fma(a0.s0, b0.s2, c02);
841 c03 = fma(a0.s0, b0.s3, c03);
842
843 c10 = fma(a0.s1, b0.s0, c10);
844 c11 = fma(a0.s1, b0.s1, c11);
845 c12 = fma(a0.s1, b0.s2, c12);
846 c13 = fma(a0.s1, b0.s3, c13);
847
848 c20 = fma(a0.s2, b0.s0, c20);
849 c21 = fma(a0.s2, b0.s1, c21);
850 c22 = fma(a0.s2, b0.s2, c22);
851 c23 = fma(a0.s2, b0.s3, c23);
852
853 c30 = fma(a0.s3, b0.s0, c30);
854 c31 = fma(a0.s3, b0.s1, c31);
855 c32 = fma(a0.s3, b0.s2, c32);
856 c33 = fma(a0.s3, b0.s3, c33);
857 }
858
859 // Compute destination address
860 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
861
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000862#if defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100863 // Multiply by the weight of matrix product
864 c00 = c00 * ALPHA;
865 c01 = c01 * ALPHA;
866 c02 = c02 * ALPHA;
867 c03 = c03 * ALPHA;
868 c10 = c10 * ALPHA;
869 c11 = c11 * ALPHA;
870 c12 = c12 * ALPHA;
871 c13 = c13 * ALPHA;
872 c20 = c20 * ALPHA;
873 c21 = c21 * ALPHA;
874 c22 = c22 * ALPHA;
875 c23 = c23 * ALPHA;
876 c30 = c30 * ALPHA;
877 c31 = c31 * ALPHA;
878 c32 = c32 * ALPHA;
879 c33 = c33 * ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +0000880#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100881
Gian Marcoae2af742018-02-15 12:35:44 +0000882 // Compute dst address
883 __global uchar *dst_addr = offset(&dst, 0, 0);
884
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000885#if defined(REINTERPRET_OUTPUT_AS_3D)
886 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100887 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000888 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100889 // | |
890 // | plane0 |
891 // | |
892 // |__________________|
893 // |******************|
894 // | cross_plane_pad |
895 // |******************|
896 // | |
897 // | plane1 |
898 // | |
899 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000900
901 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
902 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
903 zout = min(DEPTH_GEMM3D - 1, zout);
904
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100905 // Add offset due to the cross plane paddings
906 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000907
908 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
909 // multiply dst_stride_z by DEPTH_GEMM3D
910 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
911
912 // Store 4x4 block
913 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
914 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
915 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
916 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
917
918#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +0000919 // Add offset for batched GEMM
920 dst_addr += z * dst_stride_z;
921
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100922 // Store 4x4 block
Gian Marcoae2af742018-02-15 12:35:44 +0000923 vstore4((float4)(c00, c01, c02, c03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
924 vstore4((float4)(c10, c11, c12, c13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
925 vstore4((float4)(c20, c21, c22, c23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
926 vstore4((float4)(c30, c31, c32, c33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000927#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100928}
929
Georgios Pinitas84225582018-05-14 12:00:05 +0100930// Undefine local defines
931#undef COLS_MTX_B
932
Matthew Bentham6f31f8c2017-10-27 11:50:06 +0100933#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100934/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100935 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100936 *
Gian Marco19835e52018-01-30 13:35:54 +0000937 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
938 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
939 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000940 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
941 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100942 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000943 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
944 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
945 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
946 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
947 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
948 *
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100949 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
950 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
951 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
952 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
953 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
954 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100955 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100956 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
957 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
958 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
959 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
960 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +0100961 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100962 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000963 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100964 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
Gian Marco36a0a462018-01-12 10:21:40 +0000965 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100966 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000967 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
968 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
969 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +0100970 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100971 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +0100972__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
973 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +0000974 IMAGE_DECLARATION(dst),
975 uint src0_stride_z,
976 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000977 uint dst_stride_z
978#if defined(REINTERPRET_OUTPUT_AS_3D)
979 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +0100980 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +0000981#endif // REINTERPRET_OUTPUT_AS_3D
982 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100983{
Gian Marco36a0a462018-01-12 10:21:40 +0000984 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
985 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
Gian Marcoae2af742018-02-15 12:35:44 +0000986 int z = get_global_id(2);
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100987
Gian Marco36a0a462018-01-12 10:21:40 +0000988 // Offset
989 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
990 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
Anthony Barbier6ff3b192017-09-04 18:44:23 +0100991
Gian Marco36a0a462018-01-12 10:21:40 +0000992 // src_addr_a = address of matrix A
993 // src_addr_b = address of matrix B
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000994 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
995 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
996
997#if defined(MATRIX_B_DEPTH)
998 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
999 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1000#else // defined(MATRIX_B_DEPTH)
1001 src1_addr_in_bytes += z * src1_stride_z;
1002#endif // defined(MATRIX_B_DEPTH)
1003
1004 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
1005 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001006
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001007 // Compute end row address for matrix B
Gian Marco36a0a462018-01-12 10:21:40 +00001008 __global half *src_end_addr_b = src_addr_b + COLS_B;
1009
1010 src_addr_a += offset_row_a;
1011 src_addr_b += offset_row_b;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001012
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001013 // Reset accumulators
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001014 half8 c00 = 0.0f;
1015 half8 c10 = 0.0f;
1016 half8 c20 = 0.0f;
1017 half8 c30 = 0.0f;
1018
Gian Marco36a0a462018-01-12 10:21:40 +00001019 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001020 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001021 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00001022 half4 a0 = vload4(0, src_addr_a);
1023 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001024
1025 c00 += (half8)a0.s0 * b0;
1026 c10 += (half8)a0.s1 * b0;
1027 c20 += (half8)a0.s2 * b0;
1028 c30 += (half8)a0.s3 * b0;
1029
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001030 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00001031 a0 = vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT);
1032 b0 = vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001033
1034 c00 += (half8)a0.s0 * b0;
1035 c10 += (half8)a0.s1 * b0;
1036 c20 += (half8)a0.s2 * b0;
1037 c30 += (half8)a0.s3 * b0;
1038 }
1039
Gian Marco36a0a462018-01-12 10:21:40 +00001040 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001041 {
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001042 // Load values from matrix A (interleaved) and matrix B (transposed)
Gian Marco36a0a462018-01-12 10:21:40 +00001043 half4 a0 = vload4(0, src_addr_a);
1044 half8 b0 = vload8(0, src_addr_b);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001045
1046 c00 += (half8)a0.s0 * b0;
1047 c10 += (half8)a0.s1 * b0;
1048 c20 += (half8)a0.s2 * b0;
1049 c30 += (half8)a0.s3 * b0;
1050 }
1051
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001052 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001053 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1054
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001055#if defined(ALPHA)
1056 // Multiply by the weight of matrix product
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001057 c00 = c00 * (half8)ALPHA;
1058 c10 = c10 * (half8)ALPHA;
1059 c20 = c20 * (half8)ALPHA;
1060 c30 = c30 * (half8)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001061#endif // defined(ALPHA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001062
Gian Marcoae2af742018-02-15 12:35:44 +00001063 // Compute dst address
1064 __global uchar *dst_addr = offset(&dst, 0, 0);
1065
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001066#if defined(REINTERPRET_OUTPUT_AS_3D)
1067 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001068 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001069 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001070 // | |
1071 // | plane0 |
1072 // | |
1073 // |__________________|
1074 // |******************|
1075 // | cross_plane_pad |
1076 // |******************|
1077 // | |
1078 // | plane1 |
1079 // | |
1080 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001081
1082 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1083 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1084 zout = min(DEPTH_GEMM3D - 1, zout);
1085
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001086 // Add offset due to the cross plane paddings
1087 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001088
1089 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1090 // multiply dst_stride_z by DEPTH_GEMM3D
1091 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1092
1093 // Store 4x8 block
1094 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1095 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1096 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1097 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1098
1099#else // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marcoae2af742018-02-15 12:35:44 +00001100 // Add offset for batched GEMM
1101 dst_addr += z * dst_stride_z;
1102
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001103 // Store 4x8 block
Gian Marcoae2af742018-02-15 12:35:44 +00001104 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1105 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1106 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1107 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001108#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001109}
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001110
Vidhya Sudhan Loganathan38d93bd2018-11-20 15:38:13 +00001111/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) while accumulating the result in a 32 floating point variable.
1112 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
1113 *
1114 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
1115 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
1116 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
1117 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1118 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1119 *
1120 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1121 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1122 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1123 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1124 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1125 *
1126 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1127 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1128 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1129 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1130 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1131 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1132 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1133 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1134 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1135 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1136 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1137 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1138 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1139 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1140 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1141 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1142 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1143 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1144 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1145 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1146 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1147 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1148 */
1149__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
1150 IMAGE_DECLARATION(src1),
1151 IMAGE_DECLARATION(dst),
1152 uint src0_stride_z,
1153 uint src1_stride_z,
1154 uint dst_stride_z
1155#if defined(REINTERPRET_OUTPUT_AS_3D)
1156 ,
1157 uint cross_plane_pad
1158#endif // REINTERPRET_OUTPUT_AS_3D
1159 )
1160{
1161 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
1162 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
1163 int z = get_global_id(2);
1164
1165 // Offset
1166 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
1167 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
1168
1169 // src_addr_a = address of matrix A
1170 // src_addr_b = address of matrix B
1171 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
1172 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
1173
1174#if defined(MATRIX_B_DEPTH)
1175 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1176 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1177#else // defined(MATRIX_B_DEPTH)
1178 src1_addr_in_bytes += z * src1_stride_z;
1179#endif // defined(MATRIX_B_DEPTH)
1180
1181 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
1182 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
1183
1184 // Compute end row address for matrix B
1185 __global half *src_end_addr_b = src_addr_b + COLS_B;
1186
1187 src_addr_a += offset_row_a;
1188 src_addr_b += offset_row_b;
1189
1190 // Reset accumulators
1191 float8 c00 = 0.0f;
1192 float8 c10 = 0.0f;
1193 float8 c20 = 0.0f;
1194 float8 c30 = 0.0f;
1195
1196 for(; src_addr_b <= (src_end_addr_b - (int)(16 * MULT_TRANSPOSE1XW_WIDTH)); src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 16 * MULT_TRANSPOSE1XW_WIDTH)
1197 {
1198 // Load values from matrix A (interleaved) and matrix B (transposed)
1199 float4 a0 = convert_float4(vload4(0, src_addr_a));
1200 float8 b0 = convert_float8(vload8(0, src_addr_b));
1201
1202 c00 += (float8)a0.s0 * b0;
1203 c10 += (float8)a0.s1 * b0;
1204 c20 += (float8)a0.s2 * b0;
1205 c30 += (float8)a0.s3 * b0;
1206
1207 // Load values from matrix A (interleaved) and matrix B (transposed)
1208 a0 = convert_float4(vload4(0, src_addr_a + 4 * MULT_INTERLEAVE4X4_HEIGHT));
1209 b0 = convert_float8(vload8(0, src_addr_b + 8 * MULT_TRANSPOSE1XW_WIDTH));
1210
1211 c00 += (float8)a0.s0 * b0;
1212 c10 += (float8)a0.s1 * b0;
1213 c20 += (float8)a0.s2 * b0;
1214 c30 += (float8)a0.s3 * b0;
1215 }
1216
1217 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT, src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH)
1218 {
1219 // Load values from matrix A (interleaved) and matrix B (transposed)
1220 float4 a0 = convert_float4(vload4(0, src_addr_a));
1221 float8 b0 = convert_float8(vload8(0, src_addr_b));
1222
1223 c00 += (float8)a0.s0 * b0;
1224 c10 += (float8)a0.s1 * b0;
1225 c20 += (float8)a0.s2 * b0;
1226 c30 += (float8)a0.s3 * b0;
1227 }
1228
1229 // Compute destination address
1230 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1231
1232#if defined(ALPHA)
1233 // Multiply by the weight of matrix product
1234 c00 = c00 * (float8)ALPHA;
1235 c10 = c10 * (float8)ALPHA;
1236 c20 = c20 * (float8)ALPHA;
1237 c30 = c30 * (float8)ALPHA;
1238#endif // defined(ALPHA)
1239
1240 // Compute dst address
1241 __global uchar *dst_addr = offset(&dst, 0, 0);
1242
1243#if defined(REINTERPRET_OUTPUT_AS_3D)
1244 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
1245 // in order to take into account the presence of possible cross plane paddings
1246 //
1247 // | |
1248 // | plane0 |
1249 // | |
1250 // |__________________|
1251 // |******************|
1252 // | cross_plane_pad |
1253 // |******************|
1254 // | |
1255 // | plane1 |
1256 // | |
1257 // |__________________|
1258
1259 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1260 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1261 zout = min(DEPTH_GEMM3D - 1, zout);
1262
1263 // Add offset due to the cross plane paddings
1264 zout *= (cross_plane_pad * dst_stride_y);
1265
1266 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1267 // multiply dst_stride_z by DEPTH_GEMM3D
1268 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1269
1270 // Store 4x8 block
1271 vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1272 vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1273 vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1274 vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1275
1276#else // defined(REINTERPRET_OUTPUT_AS_3D)
1277 // Add offset for batched GEMM
1278 dst_addr += z * dst_stride_z;
1279
1280 // Store 4x8 block
1281 vstore8(convert_half8(c00), 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1282 vstore8(convert_half8(c10), 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1283 vstore8(convert_half8(c20), 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1284 vstore8(convert_half8(c30), 0, (__global half *)(dst_addr + 3 * dst_stride_y));
1285#endif // defined(REINTERPRET_OUTPUT_AS_3D)
1286}
1287
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001288/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A (src0) and matrix B (src1)
1289 * Matrix A and matrix B must be reshaped respectively with @ref gemm_interleave4x4_16bit and @ref gemm_transpose1x8 before running the matrix multiplication
1290 *
1291 * @note The number of columns of matrix B and the optional alpha's value need to be passed at compile time using -DCOLS_B and -DALPHA
1292 * @note The multiplication factor for the transposition width (mult_transpose1xW_width) must be passed at compile time using -DMULT_TRANSPOSE1XW_WIDTH (i.e. -DMULT_TRANSPOSE1XW_WIDTH=2)
1293 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DMULT_INTERLEAVE4X4_HEIGHT (i.e. -DMULT_INTERLEAVE4X4_HEIGHT=2)
1294 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1295 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
1296 *
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001297 * @note In case the output has to be reinterpreted as a 3D tensor (i.e. output of convolution layer), the following information must be passed at compile time:
1298 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1299 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1300 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1301 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1302 *
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001303 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1304 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1305 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1306 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1307 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1308 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1309 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1310 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1311 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1312 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1313 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1314 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1315 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1316 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1317 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1318 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1319 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1320 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001321 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001322 */
1323__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
1324 IMAGE_DECLARATION(src1),
1325 IMAGE_DECLARATION(dst),
1326 uint src0_stride_z,
1327 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001328 uint dst_stride_z
1329#if defined(REINTERPRET_OUTPUT_AS_3D)
1330 ,
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001331 uint cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001332#endif // REINTERPRET_OUTPUT_AS_3D
1333 )
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001334{
1335 int x = get_global_id(0) / MULT_TRANSPOSE1XW_WIDTH;
1336 int y = get_global_id(1) / MULT_INTERLEAVE4X4_HEIGHT;
1337 int z = get_global_id(2);
1338
1339 // Offset
1340 const int offset_row_a = (get_global_id(1) % MULT_INTERLEAVE4X4_HEIGHT) * 4;
1341 const int offset_row_b = (get_global_id(0) % MULT_TRANSPOSE1XW_WIDTH) * 8;
1342
1343 // src_addr_a = address of matrix A
1344 // src_addr_b = address of matrix B
1345 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
1346 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
1347
1348#if defined(MATRIX_B_DEPTH)
1349 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1350 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
1351#else // defined(MATRIX_B_DEPTH)
1352 src1_addr_in_bytes += z * src1_stride_z;
1353#endif // defined(MATRIX_B_DEPTH)
1354
1355 __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
1356 __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);
1357
1358 // Compute end row address for matrix B
1359 __global half *src_end_addr_b = src_addr_b + COLS_B;
1360
1361 src_addr_a += offset_row_a;
1362 src_addr_b += offset_row_b;
1363
1364 // Reset accumulators
1365 half8 c00 = 0.0f;
1366 half8 c10 = 0.0f;
1367 half8 c20 = 0.0f;
1368 half8 c30 = 0.0f;
1369
1370#define COLS_MTX_B (COLS_B / (8 * MULT_TRANSPOSE1XW_WIDTH))
1371
1372 int i = 0;
1373 for(; i <= (int)(COLS_MTX_B - 4); i += 4)
1374 {
1375#if MULT_INTERLEAVE4X4_HEIGHT == 1
1376 // Load values from matrix A (interleaved) and matrix B (transposed)
1377 half8 a0 = vload8(0, src_addr_a);
1378 half8 b0 = vload8(0, src_addr_b);
1379
1380 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
1381 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1382
1383 c00 = fma((half8)a0.s0, b0, c00);
1384 c10 = fma((half8)a0.s1, b0, c10);
1385 c20 = fma((half8)a0.s2, b0, c20);
1386 c30 = fma((half8)a0.s3, b0, c30);
1387
1388 // Load values from matrix B (transposed)
1389 b0 = vload8(0, src_addr_b);
1390
1391 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1392
1393 c00 = fma((half8)a0.s4, b0, c00);
1394 c10 = fma((half8)a0.s5, b0, c10);
1395 c20 = fma((half8)a0.s6, b0, c20);
1396 c30 = fma((half8)a0.s7, b0, c30);
1397
1398 // Load values from matrix A (interleaved) and matrix B (transposed)
1399 a0 = vload8(0, src_addr_a);
1400 b0 = vload8(0, src_addr_b);
1401
1402 src_addr_a += 8 * MULT_INTERLEAVE4X4_HEIGHT;
1403 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1404
1405 c00 = fma((half8)a0.s0, b0, c00);
1406 c10 = fma((half8)a0.s1, b0, c10);
1407 c20 = fma((half8)a0.s2, b0, c20);
1408 c30 = fma((half8)a0.s3, b0, c30);
1409
1410 // Load values from matrix B (transposed)
1411 b0 = vload8(0, src_addr_b);
1412
1413 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1414
1415 c00 = fma((half8)a0.s4, b0, c00);
1416 c10 = fma((half8)a0.s5, b0, c10);
1417 c20 = fma((half8)a0.s6, b0, c20);
1418 c30 = fma((half8)a0.s7, b0, c30);
1419#else // MULT_INTERLEAVE4X4_HEIGHT == 1
1420 // Load values from matrix A (interleaved) and matrix B (transposed)
1421 half4 a0 = vload4(0, src_addr_a);
1422 half8 b0 = vload8(0, src_addr_b);
1423
1424 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1425 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1426
1427 c00 = fma((half8)a0.s0, b0, c00);
1428 c10 = fma((half8)a0.s1, b0, c10);
1429 c20 = fma((half8)a0.s2, b0, c20);
1430 c30 = fma((half8)a0.s3, b0, c30);
1431
1432 // Load values from matrix A (interleaved) and matrix B (transposed)
1433 a0 = vload4(0, src_addr_a);
1434 b0 = vload8(0, src_addr_b);
1435
1436 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1437 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1438
1439 c00 = fma((half8)a0.s0, b0, c00);
1440 c10 = fma((half8)a0.s1, b0, c10);
1441 c20 = fma((half8)a0.s2, b0, c20);
1442 c30 = fma((half8)a0.s3, b0, c30);
1443
1444 // Load values from matrix A (interleaved) and matrix B (transposed)
1445 a0 = vload4(0, src_addr_a);
1446 b0 = vload8(0, src_addr_b);
1447
1448 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1449 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1450
1451 c00 = fma((half8)a0.s0, b0, c00);
1452 c10 = fma((half8)a0.s1, b0, c10);
1453 c20 = fma((half8)a0.s2, b0, c20);
1454 c30 = fma((half8)a0.s3, b0, c30);
1455
1456 // Load values from matrix A (interleaved) and matrix B (transposed)
1457 a0 = vload4(0, src_addr_a);
1458 b0 = vload8(0, src_addr_b);
1459
1460 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1461 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1462
1463 c00 = fma((half8)a0.s0, b0, c00);
1464 c10 = fma((half8)a0.s1, b0, c10);
1465 c20 = fma((half8)a0.s2, b0, c20);
1466 c30 = fma((half8)a0.s3, b0, c30);
1467#endif // MULT_INTERLEAVE4X4_HEIGHT == 1
1468 }
1469
1470 for(; i < (int)(COLS_MTX_B); ++i)
1471 {
1472 // Load values from matrix A (interleaved) and matrix B (transposed)
1473 half4 a0 = vload4(0, src_addr_a);
1474 half8 b0 = vload8(0, src_addr_b);
1475
1476 src_addr_a += 4 * MULT_INTERLEAVE4X4_HEIGHT;
1477 src_addr_b += 8 * MULT_TRANSPOSE1XW_WIDTH;
1478
1479 c00 = fma((half8)a0.s0, b0, c00);
1480 c10 = fma((half8)a0.s1, b0, c10);
1481 c20 = fma((half8)a0.s2, b0, c20);
1482 c30 = fma((half8)a0.s3, b0, c30);
1483 }
1484
1485 // Compute destination address
1486 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1487
1488#if defined(ALPHA)
1489 // Multiply by the weight of matrix product
1490 c00 = c00 * (half8)ALPHA;
1491 c10 = c10 * (half8)ALPHA;
1492 c20 = c20 * (half8)ALPHA;
1493 c30 = c30 * (half8)ALPHA;
1494#endif // defined(ALPHA)
1495
1496 // Compute dst address
1497 __global uchar *dst_addr = offset(&dst, 0, 0);
1498
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001499#if defined(REINTERPRET_OUTPUT_AS_3D)
1500 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001501 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001502 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001503 // | |
1504 // | plane0 |
1505 // | |
1506 // |__________________|
1507 // |******************|
1508 // | cross_plane_pad |
1509 // |******************|
1510 // | |
1511 // | plane1 |
1512 // | |
1513 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001514
1515 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
1516 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
1517 zout = min(DEPTH_GEMM3D - 1, zout);
1518
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001519 // Add offset due to the cross plane paddings
1520 zout *= (cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001521
1522 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1523 // multiply dst_stride_z by DEPTH_GEMM3D
1524 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1525
1526 // Store 4x8 block
1527 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
1528 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
1529 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
1530 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
1531
1532#else // defined(REINTERPRET_OUTPUT_AS_3D)
1533 // Add offset for batched GEMM
1534 dst_addr += z * dst_stride_z;
1535
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001536 // Store 4x8 block
1537 vstore8(c00, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
1538 vstore8(c10, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
1539 vstore8(c20, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
1540 vstore8(c30, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001541#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Gian Marco Iodicebb36a8e2018-04-19 12:05:08 +01001542}
Georgios Pinitas84225582018-05-14 12:00:05 +01001543
1544// Undefine local defines
1545#undef COLS_MTX_B
1546
Matthew Bentham6f31f8c2017-10-27 11:50:06 +01001547#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001548
Gian Marco36a0a462018-01-12 10:21:40 +00001549#endif // defined(COLS_B) && defined(MULT_TRANSPOSE1XW_WIDTH) && defined(MULT_INTERLEAVE4X4_HEIGHT)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001550
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001551#if defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
1552#if defined(DATA_TYPE)
1553#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001554/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001555 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001556 * @note This OpenCL kernel works with floating point data types (F16/F32)
1557 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1558 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001559 * @note The number of matrix A columns and the optional alpha's value need to be passed at compile time using -DCOLS_A and -DALPHA
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001560 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1561 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001562 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001563 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1564 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001565 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1566 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1567 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1568 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1569 *
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001570 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001571 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1572 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1573 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1574 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1575 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001576 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001577 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1578 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1579 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1580 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1581 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001582 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001583 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1584 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1585 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1586 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1587 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001588 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1589 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1590 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001591 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1592 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001593 */
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001594__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
1595 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001596 IMAGE_DECLARATION(dst),
1597 uint src0_stride_z,
1598 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001599 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001600#if defined(REINTERPRET_INPUT_AS_3D)
1601 ,
1602 uint src_cross_plane_pad
1603#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001604#if defined(REINTERPRET_OUTPUT_AS_3D)
1605 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001606 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001607#endif // REINTERPRET_OUTPUT_AS_3D
1608 )
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001609{
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001610 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001611
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001612 // Compute starting address for matrix A and Matrix B
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001613 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001614
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001615 // Update address for the matrix A
1616 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001617
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001618 // Update address for the matrix B
1619 src_addr.s1 += idx * sizeof(DATA_TYPE);
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001620
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001621#if defined(REINTERPRET_INPUT_AS_3D)
1622 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1623 // in order to take into account the presence of possible cross plane paddings
1624 //
1625 // | |
1626 // | plane0 |
1627 // | |
1628 // |__________________|
1629 // |******************|
1630 // | cross_plane_pad |
1631 // |******************|
1632 // | |
1633 // | plane1 |
1634 // | |
1635 // |__________________|
1636
1637 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1638 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1639 zin = min(DEPTH_GEMM3D - 1, zin);
1640
1641 // Add offset due to the cross plane paddings
1642 zin *= (src_cross_plane_pad * src0_stride_y);
1643
1644 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1645 // multiply src0_stride_z by DEPTH_GEMM3D
1646 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1647
1648#else // defined(REINTERPRET_INPUT_AS_3D)
1649
Gian Marcoae2af742018-02-15 12:35:44 +00001650 // Add offset for batched GEMM
1651 src_addr.s0 += get_global_id(2) * src0_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001652
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001653#endif // defined(REINTERPRET_INPUT_AS_3D)
1654
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001655#if defined(MATRIX_B_DEPTH)
1656 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1657 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1658#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001659 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001660#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001661
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001662 int end_row_vec_a = src_addr.s0 + (COLS_A * sizeof(DATA_TYPE));
1663
1664 VECTOR_TYPE acc0 = 0.0f;
1665#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1666 VECTOR_TYPE acc1 = 0.0f;
1667#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1668#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1669 VECTOR_TYPE acc2 = 0.0f;
1670#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1671#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1672 VECTOR_TYPE acc3 = 0.0f;
1673#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1674
Georgios Pinitas96880cf2017-10-20 18:52:20 +01001675 for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001676 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001677#if defined(REINTERPRET_INPUT_AS_3D)
1678 // Load values from matrix A
1679 VEC_DATA_TYPE(DATA_TYPE, 2)
1680 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1681#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1682 VEC_DATA_TYPE(DATA_TYPE, 2)
1683 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1684#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1685#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1686 VEC_DATA_TYPE(DATA_TYPE, 2)
1687 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1688#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1689#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1690 VEC_DATA_TYPE(DATA_TYPE, 2)
1691 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1692#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1693#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001694 // Load values from matrix A
1695 VEC_DATA_TYPE(DATA_TYPE, 2)
1696 a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1697#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1698 VEC_DATA_TYPE(DATA_TYPE, 2)
1699 a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1700#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1701#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1702 VEC_DATA_TYPE(DATA_TYPE, 2)
1703 a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1704#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1705#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1706 VEC_DATA_TYPE(DATA_TYPE, 2)
1707 a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1708#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001709#endif // defined(REINTERPRET_INPUT_AS_3D)
1710
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001711 // Load values from matrix B
1712 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
1713 VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001714
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001715 // Accumulate
1716 acc0 += b0 * (VECTOR_TYPE)a0.s0;
1717 acc0 += b1 * (VECTOR_TYPE)a0.s1;
1718#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1719 acc1 += b0 * (VECTOR_TYPE)a1.s0;
1720 acc1 += b1 * (VECTOR_TYPE)a1.s1;
1721#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1722#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1723 acc2 += b0 * (VECTOR_TYPE)a2.s0;
1724 acc2 += b1 * (VECTOR_TYPE)a2.s1;
1725#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1726#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1727 acc3 += b0 * (VECTOR_TYPE)a3.s0;
1728 acc3 += b1 * (VECTOR_TYPE)a3.s1;
1729#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001730 }
1731
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001732 for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001733 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001734#if defined(REINTERPRET_INPUT_AS_3D)
1735 // Load values from matrix A
1736 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1737#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1738 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1739#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1740#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1741 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1742#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1743#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1744 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1745#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1746#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001747 // Load values from matrix A
1748 DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1749#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1750 DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1751#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1752#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1753 DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1754#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1755#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1756 DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1757#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001758#endif // defined(REINTERPRET_INPUT_AS_3D)
1759
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001760 // Load values from matrix B
1761 VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001762
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001763 // Accumulate
1764 acc0 += b0 * (VECTOR_TYPE)a0;
1765#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1766 acc1 += b0 * (VECTOR_TYPE)a1;
1767#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1768#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1769 acc2 += b0 * (VECTOR_TYPE)a2;
1770#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1771#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1772 acc3 += b0 * (VECTOR_TYPE)a3;
1773#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001774 }
1775
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001776 // Compute destination address
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001777 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
1778
Gian Marcoae2af742018-02-15 12:35:44 +00001779 // Compute dst address
1780 __global uchar *dst_addr = offset(&dst, 0, 0);
1781
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001782 // Multiply by the weight of matrix-matrix product and store the result
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001783#if defined(ALPHA)
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001784 acc0 = acc0 * (VECTOR_TYPE)ALPHA;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001785#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001786#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1787 acc1 = acc1 * (VECTOR_TYPE)ALPHA;
1788#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
1789#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1790 acc2 = acc2 * (VECTOR_TYPE)ALPHA;
1791#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
1792#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1793 acc3 = acc3 * (VECTOR_TYPE)ALPHA;
1794#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
1795
1796 int z = get_global_id(2);
1797
1798#if defined(REINTERPRET_OUTPUT_AS_3D)
1799 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001800 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001801 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001802 // | |
1803 // | plane0 |
1804 // | |
1805 // |__________________|
1806 // |******************|
1807 // | cross_plane_pad |
1808 // |******************|
1809 // | |
1810 // | plane1 |
1811 // | |
1812 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001813
1814 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1815 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1816 zout = min(DEPTH_GEMM3D - 1, zout);
1817
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01001818 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001819 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001820
1821 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1822 // multiply dst_stride_z by DEPTH_GEMM3D
1823 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
1824
1825 // Store output block
1826 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1827 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y + zout.s0));
1828#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1829 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1830 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y + zout.s1));
1831#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1832#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1833 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1834 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y + zout.s2));
1835#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1836#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1837 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
1838 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y + zout.s3));
1839#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1840
1841#else // defined(REINTERPRET_OUTPUT_AS_3D)
1842 // Add offset for batched GEMM
1843 dst_addr += z * dst_stride_z;
1844
1845 // Store output block
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001846 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001847 (acc0, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001848#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001849 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001850 (acc1, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001851#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1852#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001853 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001854 (acc2, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001855#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1856#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001857 VSTORE(NUM_ELEMS_PROCESSED_PER_THREAD_X)
Gian Marcoae2af742018-02-15 12:35:44 +00001858 (acc3, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_y));
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001859#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001860#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01001861}
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01001862#endif // defined(DATA_TYPE)
Gian Marco Iodice3a3066b2017-06-23 13:38:14 +01001863
Michele Di Giorgiof6f08da2018-04-26 10:24:30 +01001864/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001865 *
1866 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1867 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
1868 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
1869 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
1870 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001871 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
1872 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001873 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001874 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1875 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001876 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1877 * -# HEIGHT_GEMM3D: The height of the input or output in case it has to be reinterpreted as a 3D tensor.
1878 * -# DEPTH_GEMM3D: The depth of the input or output in case it has to be reinterpreted as a 3D tensor.
1879 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = rows of matrix A NOT reshaped
1880 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001881 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1882 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1883 * @param[in] src0_step_x src0_stride_x * number of elements along X processed per workitem (in bytes)
1884 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1885 * @param[in] src0_step_y src0_stride_y * number of elements along Y processed per workitem (in bytes)
1886 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1887 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1888 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1889 * @param[in] src1_step_x src1_stride_x * number of elements along X processed per workitem (in bytes)
1890 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1891 * @param[in] src1_step_y src1_stride_y * number of elements along Y processed per workitem (in bytes)
1892 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1893 * @param[out] dst_ptr Pointer to the destination matrix. Supported data types: same as @p src0_ptr
1894 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1895 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem (in bytes)
1896 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1897 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem (in bytes)
1898 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001899 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1900 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1901 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001902 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1903 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001904 */
1905__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1906 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00001907 IMAGE_DECLARATION(dst),
1908 uint src0_stride_z,
1909 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001910 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001911#if defined(REINTERPRET_INPUT_AS_3D)
1912 ,
1913 uint src_cross_plane_pad
1914#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001915#if defined(REINTERPRET_OUTPUT_AS_3D)
1916 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001917 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00001918#endif // REINTERPRET_OUTPUT_AS_3D
1919 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001920{
1921 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
1922
1923 // Compute starting address for matrix A and matrix B
1924 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1925
1926 // Update address for matrix A
1927 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
1928
1929 // Update address for matrix B
1930 src_addr.s1 += idx * sizeof(float);
1931
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001932#if defined(REINTERPRET_INPUT_AS_3D)
1933 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1934 // in order to take into account the presence of possible cross plane paddings
1935 //
1936 // | |
1937 // | plane0 |
1938 // | |
1939 // |__________________|
1940 // |******************|
1941 // | cross_plane_pad |
1942 // |******************|
1943 // | |
1944 // | plane1 |
1945 // | |
1946 // |__________________|
1947
1948 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
1949 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
1950 zin = min(DEPTH_GEMM3D - 1, zin);
1951
1952 // Add offset due to the cross plane paddings
1953 zin *= (src_cross_plane_pad * src0_stride_y);
1954
1955 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1956 // multiply src0_stride_z by DEPTH_GEMM3D
1957 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1958
1959#else // defined(REINTERPRET_INPUT_AS_3D)
1960
Gian Marcoae2af742018-02-15 12:35:44 +00001961 // Add offset for batched GEMM
1962 src_addr.s0 += get_global_id(2) * src0_stride_z;
1963
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01001964#endif // defined(REINTERPRET_INPUT_AS_3D)
1965
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001966#if defined(MATRIX_B_DEPTH)
1967 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1968 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1969#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001970 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001971#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00001972
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00001973 // Initialize accumulators
1974 float acc00 = 0.0f;
1975 float acc01 = 0.0f;
1976 float acc02 = 0.0f;
1977 float acc03 = 0.0f;
1978
1979#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1980 float acc10 = 0.0f;
1981 float acc11 = 0.0f;
1982 float acc12 = 0.0f;
1983 float acc13 = 0.0f;
1984#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
1985
1986#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1987 float acc20 = 0.0f;
1988 float acc21 = 0.0f;
1989 float acc22 = 0.0f;
1990 float acc23 = 0.0f;
1991#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
1992
1993#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1994 float acc30 = 0.0f;
1995 float acc31 = 0.0f;
1996 float acc32 = 0.0f;
1997 float acc33 = 0.0f;
1998#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
1999
2000 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002001 int i = 0;
2002 for(; i <= ((int)COLS_A - 4); i += 4)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002003 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002004#if defined(REINTERPRET_INPUT_AS_3D)
2005 // Load values from matrix A and matrix B
2006 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2007#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2008 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2009#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2010#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2011 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2012#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2013#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2014 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2015#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2016#else // defined(REINTERPRET_INPUT_AS_3D)
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002017 // Load values from matrix A and matrix B
2018 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002019#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002020 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002021#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2022#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002023 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002024#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2025#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002026 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002027#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002028#endif // defined(REINTERPRET_INPUT_AS_3D)
2029
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002030 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2031 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002032
2033 // Multiply and accumulate
2034 acc00 = fma(a0.s0, b0.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002035 acc01 = fma(a0.s0, b0.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002036 acc02 = fma(a0.s0, b0.s2, acc02);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002037 acc03 = fma(a0.s0, b0.s3, acc03);
2038
2039#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002040
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002041 acc10 = fma(a1.s0, b0.s0, acc10);
2042 acc11 = fma(a1.s0, b0.s1, acc11);
2043 acc12 = fma(a1.s0, b0.s2, acc12);
2044 acc13 = fma(a1.s0, b0.s3, acc13);
2045
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002046#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2047#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002048
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002049 acc20 = fma(a2.s0, b0.s0, acc20);
2050 acc21 = fma(a2.s0, b0.s1, acc21);
2051 acc22 = fma(a2.s0, b0.s2, acc22);
2052 acc23 = fma(a2.s0, b0.s3, acc23);
2053
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002054#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2055#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002056
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002057 acc30 = fma(a3.s0, b0.s0, acc30);
2058 acc31 = fma(a3.s0, b0.s1, acc31);
2059 acc32 = fma(a3.s0, b0.s2, acc32);
2060 acc33 = fma(a3.s0, b0.s3, acc33);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002061#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002062
2063 // Load values from matrix A and matrix B
2064 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2065 src_addr.s1 += src1_stride_y;
2066
2067 // Multiply and accumulate
2068 acc00 = fma(a0.s1, b0.s0, acc00);
2069 acc01 = fma(a0.s1, b0.s1, acc01);
2070 acc02 = fma(a0.s1, b0.s2, acc02);
2071 acc03 = fma(a0.s1, b0.s3, acc03);
2072
2073#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2074
2075 acc10 = fma(a1.s1, b0.s0, acc10);
2076 acc11 = fma(a1.s1, b0.s1, acc11);
2077 acc12 = fma(a1.s1, b0.s2, acc12);
2078 acc13 = fma(a1.s1, b0.s3, acc13);
2079
2080#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2081#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2082
2083 acc20 = fma(a2.s1, b0.s0, acc20);
2084 acc21 = fma(a2.s1, b0.s1, acc21);
2085 acc22 = fma(a2.s1, b0.s2, acc22);
2086 acc23 = fma(a2.s1, b0.s3, acc23);
2087
2088#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2089#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2090
2091 acc30 = fma(a3.s1, b0.s0, acc30);
2092 acc31 = fma(a3.s1, b0.s1, acc31);
2093 acc32 = fma(a3.s1, b0.s2, acc32);
2094 acc33 = fma(a3.s1, b0.s3, acc33);
2095#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2096
2097 // Load values from matrix A and matrix B
2098 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2099 src_addr.s1 += src1_stride_y;
2100
2101 // Multiply and accumulate
2102 acc00 = fma(a0.s2, b0.s0, acc00);
2103 acc01 = fma(a0.s2, b0.s1, acc01);
2104 acc02 = fma(a0.s2, b0.s2, acc02);
2105 acc03 = fma(a0.s2, b0.s3, acc03);
2106
2107#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2108
2109 acc10 = fma(a1.s2, b0.s0, acc10);
2110 acc11 = fma(a1.s2, b0.s1, acc11);
2111 acc12 = fma(a1.s2, b0.s2, acc12);
2112 acc13 = fma(a1.s2, b0.s3, acc13);
2113
2114#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2115#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2116
2117 acc20 = fma(a2.s2, b0.s0, acc20);
2118 acc21 = fma(a2.s2, b0.s1, acc21);
2119 acc22 = fma(a2.s2, b0.s2, acc22);
2120 acc23 = fma(a2.s2, b0.s3, acc23);
2121
2122#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2123#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2124
2125 acc30 = fma(a3.s2, b0.s0, acc30);
2126 acc31 = fma(a3.s2, b0.s1, acc31);
2127 acc32 = fma(a3.s2, b0.s2, acc32);
2128 acc33 = fma(a3.s2, b0.s3, acc33);
2129#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2130
2131 // Load values from matrix A and matrix B
2132 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
2133 src_addr.s1 += src1_stride_y;
2134
2135 // Multiply and accumulate
2136 acc00 = fma(a0.s3, b0.s0, acc00);
2137 acc01 = fma(a0.s3, b0.s1, acc01);
2138 acc02 = fma(a0.s3, b0.s2, acc02);
2139 acc03 = fma(a0.s3, b0.s3, acc03);
2140
2141#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2142
2143 acc10 = fma(a1.s3, b0.s0, acc10);
2144 acc11 = fma(a1.s3, b0.s1, acc11);
2145 acc12 = fma(a1.s3, b0.s2, acc12);
2146 acc13 = fma(a1.s3, b0.s3, acc13);
2147
2148#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2149#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2150
2151 acc20 = fma(a2.s3, b0.s0, acc20);
2152 acc21 = fma(a2.s3, b0.s1, acc21);
2153 acc22 = fma(a2.s3, b0.s2, acc22);
2154 acc23 = fma(a2.s3, b0.s3, acc23);
2155
2156#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2157#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2158
2159 acc30 = fma(a3.s3, b0.s0, acc30);
2160 acc31 = fma(a3.s3, b0.s1, acc31);
2161 acc32 = fma(a3.s3, b0.s2, acc32);
2162 acc33 = fma(a3.s3, b0.s3, acc33);
2163#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2164
2165 src_addr.s0 += 4 * sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002166 }
2167
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002168 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002169 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002170#if defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002171 // Load values from matrix A
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002172 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2173#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2174 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2175#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2176#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2177 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2178#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2179#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2180 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2181#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2182#else // defined(REINTERPRET_INPUT_AS_3D)
2183 // Load values from matrix A
2184 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002185#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2186 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2187#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2188#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2189 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2190#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2191#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2192 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2193#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002194#endif // defined(REINTERPRET_INPUT_AS_3D)
2195
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002196 // Load values from matrix B
2197 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002198 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002199
2200 // Multiply and accumulate
2201 acc00 = fma(a0, b0.s0, acc00);
2202 acc01 = fma(a0, b0.s1, acc01);
2203 acc02 = fma(a0, b0.s2, acc02);
2204 acc03 = fma(a0, b0.s3, acc03);
2205#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2206 acc10 = fma(a1, b0.s0, acc10);
2207 acc11 = fma(a1, b0.s1, acc11);
2208 acc12 = fma(a1, b0.s2, acc12);
2209 acc13 = fma(a1, b0.s3, acc13);
2210#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2211#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2212 acc20 = fma(a2, b0.s0, acc20);
2213 acc21 = fma(a2, b0.s1, acc21);
2214 acc22 = fma(a2, b0.s2, acc22);
2215 acc23 = fma(a2, b0.s3, acc23);
2216#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2217#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2218 acc30 = fma(a3, b0.s0, acc30);
2219 acc31 = fma(a3, b0.s1, acc31);
2220 acc32 = fma(a3, b0.s2, acc32);
2221 acc33 = fma(a3, b0.s3, acc33);
2222#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002223
2224 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002225 }
2226
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002227 int z = get_global_id(2);
2228
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002229 // Compute destination address
2230 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2231
2232 // Multiply by the weight of matrix-matrix product and store the result
2233#if defined(ALPHA)
2234 acc00 = acc00 * ALPHA;
2235 acc01 = acc01 * ALPHA;
2236 acc02 = acc02 * ALPHA;
2237 acc03 = acc03 * ALPHA;
2238#endif // defined(ALPHA)
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002239#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002240 acc10 = acc10 * ALPHA;
2241 acc11 = acc11 * ALPHA;
2242 acc12 = acc12 * ALPHA;
2243 acc13 = acc13 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002244#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2245#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002246 acc20 = acc20 * ALPHA;
2247 acc21 = acc21 * ALPHA;
2248 acc22 = acc22 * ALPHA;
2249 acc23 = acc23 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002250#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2251#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002252 acc30 = acc30 * ALPHA;
2253 acc31 = acc31 * ALPHA;
2254 acc32 = acc32 * ALPHA;
2255 acc33 = acc33 * ALPHA;
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002256#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2257
2258 // Compute dst address
2259 __global uchar *dst_addr = offset(&dst, 0, 0);
2260
2261#if defined(REINTERPRET_OUTPUT_AS_3D)
2262 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002263 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002264 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002265 // | |
2266 // | plane0 |
2267 // | |
2268 // |__________________|
2269 // |******************|
2270 // | cross_plane_pad |
2271 // |******************|
2272 // | |
2273 // | plane1 |
2274 // | |
2275 // |__________________|
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002276
2277 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2278 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2279 zout = min(DEPTH_GEMM3D - 1, zout);
2280
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002281 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002282 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002283
2284 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2285 // multiply dst_stride_z by DEPTH_GEMM3D
2286 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2287
2288 // Store the output block
2289 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
2290#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2291 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
2292#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2293#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2294 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
2295#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2296#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2297 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002298#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002299
2300#else // defined(REINTERPRET_OUTPUT_AS_3D)
2301 // Add offset for batched GEMM
2302 dst_addr += z * dst_stride_z;
2303
2304 // Store the output block
2305 vstore4((float4)(acc00, acc01, acc02, acc03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2306#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2307 vstore4((float4)(acc10, acc11, acc12, acc13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2308#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2309#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2310 vstore4((float4)(acc20, acc21, acc22, acc23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2311#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2312#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2313 vstore4((float4)(acc30, acc31, acc32, acc33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
2314#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2315#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002316}
2317
2318/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2319 *
2320 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
2321 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
2322 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2323 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=2.
2324 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2325 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha if alpha!=1.0f.
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002326 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2327 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002328 *
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002329 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2330 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002331 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2332 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2333 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2334 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2335 *
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002336 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
2337 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2338 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2339 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2340 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2341 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2342 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2343 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2344 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2345 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2346 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2347 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2348 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2349 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2350 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2351 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2352 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2353 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002354 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2355 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2356 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002357 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2358 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002359 */
2360__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
2361 IMAGE_DECLARATION(src1),
Gian Marcoae2af742018-02-15 12:35:44 +00002362 IMAGE_DECLARATION(dst),
2363 uint src0_stride_z,
2364 uint src1_stride_z,
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002365 uint dst_stride_z
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002366#if defined(REINTERPRET_INPUT_AS_3D)
2367 ,
2368 uint src_cross_plane_pad
2369#endif // REINTERPRET_INPUT_AS_3D
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002370#if defined(REINTERPRET_OUTPUT_AS_3D)
2371 ,
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002372 uint dst_cross_plane_pad
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002373#endif // REINTERPRET_OUTPUT_AS_3D
2374 )
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002375{
2376 // Requires 2 NUM_ELEMS_PROCESSED_PER_THREAD_X, C vect2, A vect4, B (2 vload2) // to fix for NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2377 int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;
2378
2379 // Compute starting address for matrix A and Matrix B
2380 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2381
2382 // Update address for the matrix A
2383 src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;
2384
2385 // Update address for the matrix B
2386 src_addr.s1 += idx * sizeof(float);
2387
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002388#if defined(REINTERPRET_INPUT_AS_3D)
2389 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2390 // in order to take into account the presence of possible cross plane paddings
2391 //
2392 // | |
2393 // | plane0 |
2394 // | |
2395 // |__________________|
2396 // |******************|
2397 // | cross_plane_pad |
2398 // |******************|
2399 // | |
2400 // | plane1 |
2401 // | |
2402 // |__________________|
2403
2404 // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2405 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2406 zin = min(DEPTH_GEMM3D - 1, zin);
2407
2408 // Add offset due to the cross plane paddings
2409 zin *= (src_cross_plane_pad * src0_stride_y);
2410
2411 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2412 // multiply src0_stride_z by DEPTH_GEMM3D
2413 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2414
2415#else // defined(REINTERPRET_INPUT_AS_3D)
2416
Gian Marcoae2af742018-02-15 12:35:44 +00002417 // Add offset for batched GEMM
2418 src_addr.s0 += get_global_id(2) * src0_stride_z;
2419
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002420#endif // defined(REINTERPRET_INPUT_AS_3D)
2421
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002422#if defined(MATRIX_B_DEPTH)
2423 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2424 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2425#else // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002426 src_addr.s1 += get_global_id(2) * src1_stride_z;
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00002427#endif // defined(MATRIX_B_DEPTH)
Gian Marcoae2af742018-02-15 12:35:44 +00002428
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002429 // Initialize accumulators
2430 float acc00 = 0.0f;
2431 float acc01 = 0.0f;
2432
2433#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2434 float acc10 = 0.0f;
2435 float acc11 = 0.0f;
2436#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2437#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2438 float acc20 = 0.0f;
2439 float acc21 = 0.0f;
2440#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2441#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2442 float acc30 = 0.0f;
2443 float acc31 = 0.0f;
2444#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2445
2446 // A and B src indices get incremented at the same time.
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002447 int i = 0;
2448 for(; i <= ((int)COLS_A - 8); i += 8)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002449 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002450#if defined(REINTERPRET_INPUT_AS_3D)
2451 // Load values from matrix A
2452 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
2453#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002454 // Load values from matrix A
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002455 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002456#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002457
2458 // Load values from matrix B
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002459 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2460 src_addr.s1 += src1_stride_y;
2461 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2462 src_addr.s1 += src1_stride_y;
2463 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2464 src_addr.s1 += src1_stride_y;
2465 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2466 src_addr.s1 += src1_stride_y;
2467 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2468 src_addr.s1 += src1_stride_y;
2469 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2470 src_addr.s1 += src1_stride_y;
2471 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2472 src_addr.s1 += src1_stride_y;
2473 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2474 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002475
2476 // Multiply and accumulate
2477 acc00 = fma(a0.s0, b0.s0, acc00);
2478 acc00 = fma(a0.s1, b1.s0, acc00);
2479 acc00 = fma(a0.s2, b2.s0, acc00);
2480 acc00 = fma(a0.s3, b3.s0, acc00);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002481 acc00 = fma(a0.s4, b4.s0, acc00);
2482 acc00 = fma(a0.s5, b5.s0, acc00);
2483 acc00 = fma(a0.s6, b6.s0, acc00);
2484 acc00 = fma(a0.s7, b7.s0, acc00);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002485
2486 acc01 = fma(a0.s0, b0.s1, acc01);
2487 acc01 = fma(a0.s1, b1.s1, acc01);
2488 acc01 = fma(a0.s2, b2.s1, acc01);
2489 acc01 = fma(a0.s3, b3.s1, acc01);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002490 acc01 = fma(a0.s4, b4.s1, acc01);
2491 acc01 = fma(a0.s5, b5.s1, acc01);
2492 acc01 = fma(a0.s6, b6.s1, acc01);
2493 acc01 = fma(a0.s7, b7.s1, acc01);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002494
2495#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002496#if defined(REINTERPRET_INPUT_AS_3D)
2497 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2498#else // defined(REINTERPRET_INPUT_AS_3D)
2499 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2500#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002501 acc10 = fma(a0.s0, b0.s0, acc10);
2502 acc10 = fma(a0.s1, b1.s0, acc10);
2503 acc10 = fma(a0.s2, b2.s0, acc10);
2504 acc10 = fma(a0.s3, b3.s0, acc10);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002505 acc10 = fma(a0.s4, b4.s0, acc10);
2506 acc10 = fma(a0.s5, b5.s0, acc10);
2507 acc10 = fma(a0.s6, b6.s0, acc10);
2508 acc10 = fma(a0.s7, b7.s0, acc10);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002509
2510 acc11 = fma(a0.s0, b0.s1, acc11);
2511 acc11 = fma(a0.s1, b1.s1, acc11);
2512 acc11 = fma(a0.s2, b2.s1, acc11);
2513 acc11 = fma(a0.s3, b3.s1, acc11);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002514 acc11 = fma(a0.s4, b4.s1, acc11);
2515 acc11 = fma(a0.s5, b5.s1, acc11);
2516 acc11 = fma(a0.s6, b6.s1, acc11);
2517 acc11 = fma(a0.s7, b7.s1, acc11);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002518#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2519#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002520#if defined(REINTERPRET_INPUT_AS_3D)
2521 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2522#else // defined(REINTERPRET_INPUT_AS_3D)
2523 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2524#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002525 acc20 = fma(a0.s0, b0.s0, acc20);
2526 acc20 = fma(a0.s1, b1.s0, acc20);
2527 acc20 = fma(a0.s2, b2.s0, acc20);
2528 acc20 = fma(a0.s3, b3.s0, acc20);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002529 acc20 = fma(a0.s4, b4.s0, acc20);
2530 acc20 = fma(a0.s5, b5.s0, acc20);
2531 acc20 = fma(a0.s6, b6.s0, acc20);
2532 acc20 = fma(a0.s7, b7.s0, acc20);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002533
2534 acc21 = fma(a0.s0, b0.s1, acc21);
2535 acc21 = fma(a0.s1, b1.s1, acc21);
2536 acc21 = fma(a0.s2, b2.s1, acc21);
2537 acc21 = fma(a0.s3, b3.s1, acc21);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002538 acc21 = fma(a0.s4, b4.s1, acc21);
2539 acc21 = fma(a0.s5, b5.s1, acc21);
2540 acc21 = fma(a0.s6, b6.s1, acc21);
2541 acc21 = fma(a0.s7, b7.s1, acc21);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002542#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2543#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002544#if defined(REINTERPRET_INPUT_AS_3D)
2545 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2546#else // defined(REINTERPRET_INPUT_AS_3D)
2547 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2548#endif // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002549 acc30 = fma(a0.s0, b0.s0, acc30);
2550 acc30 = fma(a0.s1, b1.s0, acc30);
2551 acc30 = fma(a0.s2, b2.s0, acc30);
2552 acc30 = fma(a0.s3, b3.s0, acc30);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002553 acc30 = fma(a0.s4, b4.s0, acc30);
2554 acc30 = fma(a0.s5, b5.s0, acc30);
2555 acc30 = fma(a0.s6, b6.s0, acc30);
2556 acc30 = fma(a0.s7, b7.s0, acc30);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002557
2558 acc31 = fma(a0.s0, b0.s1, acc31);
2559 acc31 = fma(a0.s1, b1.s1, acc31);
2560 acc31 = fma(a0.s2, b2.s1, acc31);
2561 acc31 = fma(a0.s3, b3.s1, acc31);
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002562 acc31 = fma(a0.s4, b4.s1, acc31);
2563 acc31 = fma(a0.s5, b5.s1, acc31);
2564 acc31 = fma(a0.s6, b6.s1, acc31);
2565 acc31 = fma(a0.s7, b7.s1, acc31);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002566#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002567
2568 src_addr.s0 += sizeof(float) * 8;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002569 }
2570 // float size increment
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002571 for(; i < (int)COLS_A; ++i)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002572 {
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002573#if defined(REINTERPRET_INPUT_AS_3D)
2574 // Load values from matrix A
2575 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2576#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2577 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2578#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2579#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2580 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2581#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2582#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2583 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2584#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2585#else // defined(REINTERPRET_INPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002586 // Load values from matrix A
2587 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2588#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2589 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2590#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2591#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2592 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2593#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2594#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2595 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2596#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002597#endif // defined(REINTERPRET_INPUT_AS_3D)
2598
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002599 // Load values from matrix B
2600 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002601 src_addr.s1 += src1_stride_y;
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002602
2603 // Multiply and accumulate
2604 acc00 = fma(a0, b0.s0, acc00);
2605 acc01 = fma(a0, b0.s1, acc01);
2606#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2607 acc10 = fma(a1, b0.s0, acc10);
2608 acc11 = fma(a1, b0.s1, acc11);
2609#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2610#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2611 acc20 = fma(a2, b0.s0, acc20);
2612 acc21 = fma(a2, b0.s1, acc21);
2613#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2614#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2615 acc30 = fma(a3, b0.s0, acc30);
2616 acc31 = fma(a3, b0.s1, acc31);
2617#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Gian Marco Iodicec9c62c22018-04-06 10:00:10 +01002618
2619 src_addr.s0 += sizeof(float);
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002620 }
2621
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002622 // Multiply by the weight of matrix-matrix product and store the result
2623#if defined(ALPHA)
2624 acc00 = acc00 * ALPHA;
2625 acc01 = acc01 * ALPHA;
2626#endif // defined(ALPHA)
2627#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2628 acc10 = acc10 * ALPHA;
2629 acc11 = acc11 * ALPHA;
2630#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 && defined(ALPHA)
2631#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2632 acc20 = acc20 * ALPHA;
2633 acc21 = acc21 * ALPHA;
2634#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 && defined(ALPHA)
2635#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2636 acc30 = acc30 * ALPHA;
2637 acc31 = acc31 * ALPHA;
2638#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 && defined(ALPHA)
2639
2640 int z = get_global_id(2);
2641
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002642 // Compute destination address
2643 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
2644
Gian Marcoae2af742018-02-15 12:35:44 +00002645 // Compute dst address
2646 __global uchar *dst_addr = offset(&dst, 0, 0);
2647
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002648#if defined(REINTERPRET_OUTPUT_AS_3D)
2649 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002650 // in order to take into account the presence of possible cross plane paddings
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002651 //
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002652 // | |
2653 // | plane0 |
2654 // | |
2655 // |__________________|
2656 // |******************|
2657 // | cross_plane_pad |
2658 // |******************|
2659 // | |
2660 // | plane1 |
2661 // | |
2662 // |__________________|
Gian Marcoae2af742018-02-15 12:35:44 +00002663
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002664 // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
2665 uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
2666 zout = min(DEPTH_GEMM3D - 1, zout);
2667
Georgios Pinitase8bd2c72018-07-11 15:54:56 +01002668 // Add offset due to the cross plane paddings
Gian Marco Iodice68a3f562018-07-26 11:44:03 +01002669 zout *= (dst_cross_plane_pad * dst_stride_y);
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002670
2671 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2672 // multiply dst_stride_z by DEPTH_GEMM3D
2673 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2674
2675 // Store the output block
2676 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y + zout.s0));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002677#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002678 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y + zout.s1));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002679#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2680#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002681 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y + zout.s2));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002682#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2683#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002684 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y + zout.s3));
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002685#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
Isabella Gottardi8e74f442018-03-01 16:42:00 +00002686
2687#else // defined(REINTERPRET_OUTPUT_AS_3D)
2688 // Add offset for batched GEMM
2689 dst_addr += z * dst_stride_z;
2690
2691 // Store the output block
2692 vstore2((float2)(acc00, acc01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
2693#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2694 vstore2((float2)(acc10, acc11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
2695#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
2696#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2697 vstore2((float2)(acc20, acc21), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
2698#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
2699#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2700 vstore2((float2)(acc30, acc31), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
2701#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
2702#endif // defined(REINTERPRET_OUTPUT_AS_3D)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00002703}
2704
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01002705#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2707 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
2709 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
2710 * This kernel optimally uses -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4.
2711 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
2712 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
2713 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
2714 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
2715 *
2716 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2717 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2718 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2719 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2720 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2721 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2722 *
2723 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2724 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2725 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2726 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2727 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2728 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2729 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2730 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2731 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2732 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2733 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2734 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2735 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2736 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2737 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2738 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2739 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2740 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2741 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2742 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2743 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2744 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2745 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2746 */
__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
                                                       uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                       ,
                                                       uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                       )
{
    // Summary: each work-item multiplies up to 4 rows of matrix A (src0) by an 8-column wide strip of matrix B (src1).
    // Inputs and outputs are half precision, the partial sums are kept in single precision (float8) and are only
    // narrowed to half right before the store.

    // Index of the first column of matrix B (and of the output) processed by this work-item
    int idx = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Compute starting address for matrix A and Matrix B (s0 = byte offset in A, s1 = byte offset in B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings (one entry per processed row, in bytes)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Single precision accumulators, one per processed row.
    // They are float vectors, so they are initialised with a float literal rather than a half one.
    float8 acc0 = 0.0f;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    float8 acc1 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    float8 acc2 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    float8 acc3 = 0.0f;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: consume 4 columns of matrix A (and 4 rows of matrix B) per iteration
    int i = 0;
    for(; i <= ((int)COLS_A - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (widened to float) and move to the next row of matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;

        // Accumulate: first of the four columns of matrix A
        acc0 = fma(b0, (float8)a0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Second column
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Third column
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Fourth column
        b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (float8)a0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Advance matrix A by the four columns just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: process the remaining columns of matrix A one at a time
    for(; i < (int)COLS_A; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));

        // Move one element forward in matrix A and one row forward in matrix B
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate (the scalar half value of A is broadcast and widened to float8)
        acc0 = fma(b0, (float8)a0, acc0); // acc0 += b0 * (float8)a0
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(b0, (float8)a1, acc1); // acc1 += b0 * (float8)a1
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(b0, (float8)a2, acc2); // acc2 += b0 * (float8)a2
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(b0, (float8)a3, acc3); // acc3 += b0 * (float8)a3
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

    // Multiply by the weight of matrix-matrix product and store the result
    // NOTE(review): the accumulators are narrowed to half before being scaled by ALPHA, so the scaling itself
    // happens in half precision. Scaling in float before the conversion would keep more precision - confirm
    // whether the reference implementation relies on the current order before changing it.
#if defined(ALPHA)
    half8 hacc0 = convert_half8(acc0) * (half8)ALPHA;
#else  // defined(ALPHA)
    half8 hacc0 = convert_half8(acc0);
#endif // defined(ALPHA)
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if defined(ALPHA)
    half8 hacc1 = convert_half8(acc1) * (half8)ALPHA;
#else  // defined(ALPHA)
    half8 hacc1 = convert_half8(acc1);
#endif // defined(ALPHA)
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if defined(ALPHA)
    half8 hacc2 = convert_half8(acc2) * (half8)ALPHA;
#else  // defined(ALPHA)
    half8 hacc2 = convert_half8(acc2);
#endif // defined(ALPHA)
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2

#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if defined(ALPHA)
    half8 hacc3 = convert_half8(acc3) * (half8)ALPHA;
#else  // defined(ALPHA)
    half8 hacc3 = convert_half8(acc3);
#endif // defined(ALPHA)
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Batch index
    int z = get_global_id(2);

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y) by HEIGHT_GEMM3D
    uint4 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    zout       = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings (one entry per processed row, in bytes)
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + zout.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + zout.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + zout.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + zout.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

#else // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;

    // Store the output block
    vstore8(hacc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(hacc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(hacc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(hacc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // REINTERPRET_OUTPUT_AS_3D
}
3053
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units. Both the operands and the accumulators are half precision.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DNUM_ELEMS_PROCESSED_PER_THREAD_X and -DNUM_ELEMS_PROCESSED_PER_THREAD_Y.
 *       Every work-item reads 8 values of matrix B and writes 8 output values per row (vload8 / vstore8) and handles up to 4 rows.
 *       NOTE(review): this comment used to recommend -DNUM_ELEMS_PROCESSED_PER_THREAD_X=4, which does not match the 8-wide accesses; confirm the value used by the host code.
 * @note The number of matrix A columns must be passed at compile time using -DCOLS_A.
 * @note The optional value of scalar alpha is passed at compile time using -DALPHA=alpha
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (i.e. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (i.e. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of a plane of the 3D tensor. The row index is divided by this value to find the plane a row belongs to.
 *       -# DEPTH_GEMM3D: The number of planes of the 3D tensor. The computed plane index is clamped to DEPTH_GEMM3D - 1.
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix A in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix B in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom padding of every input plane, expressed in rows: it is multiplied by src0_stride_y (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom padding of every output plane, expressed in rows: it is multiplied by dst_stride_y (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
{
    // First column of matrix B (and of the output) owned by this work-item
    const int first_col = get_global_id(0) * NUM_ELEMS_PROCESSED_PER_THREAD_X;

    // Running byte offsets inside matrix A (lhs) and matrix B (rhs)
    int lhs_offset = src0_offset_first_element_in_bytes;
    int rhs_offset = src1_offset_first_element_in_bytes;

    // Matrix A: jump to the first row handled by this work-item
    lhs_offset += get_global_id(1) * src0_stride_y * NUM_ELEMS_PROCESSED_PER_THREAD_Y;

    // Matrix B: jump to the first column handled by this work-item
    rhs_offset += first_col * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // The 2D tile of matrix A is read from a 3D tensor whose planes can be separated by padding rows:
    //
    //  |                  |
    //  |      plane0      |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |__________________|
    //
    // For each of the (up to) four rows: plane index = row index / HEIGHT_GEMM3D, clamped to the last plane
    uint4 lhs_plane_pad = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    lhs_plane_pad       = min(DEPTH_GEMM3D - 1, lhs_plane_pad);

    // Turn the plane index into the number of padding bytes that precede the row
    lhs_plane_pad *= (src_cross_plane_pad * src0_stride_y);

    // Batched GEMM: the batches live in the fourth dimension, hence the DEPTH_GEMM3D factor
    lhs_offset += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_INPUT_AS_3D)
    // Plain 2D input: no padding between rows, so the per-row extra offset is zero
    const uint4 lhs_plane_pad = (uint4)(0);

    // Batched GEMM
    lhs_offset += get_global_id(2) * src0_stride_z;
#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Matrix B has 3 dimensions while matrix A has more: wrap the batch index instead of sliding past matrix B
    rhs_offset += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    rhs_offset += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half precision accumulator (8 output columns) per processed row
    half8 acc0 = (half8)(0.0h);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    half8 acc1 = (half8)(0.0h);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    half8 acc2 = (half8)(0.0h);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    half8 acc3 = (half8)(0.0h);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

    // Main loop: four columns of matrix A / four rows of matrix B per iteration
    int k = 0;
    for(; k <= ((int)COLS_A - 4); k += 4)
    {
        // Four consecutive values from every processed row of matrix A
        const half4 lhs0 = vload4(0, (__global half *)(src0_ptr + lhs_offset + 0 * src0_stride_y + lhs_plane_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const half4 lhs1 = vload4(0, (__global half *)(src0_ptr + lhs_offset + 1 * src0_stride_y + lhs_plane_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const half4 lhs2 = vload4(0, (__global half *)(src0_ptr + lhs_offset + 2 * src0_stride_y + lhs_plane_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const half4 lhs3 = vload4(0, (__global half *)(src0_ptr + lhs_offset + 3 * src0_stride_y + lhs_plane_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 0: row k of matrix B times element 0 of every matrix A row
        const half8 rhs_k0 = vload8(0, (__global half *)(src1_ptr + rhs_offset));
        rhs_offset += src1_stride_y;
        acc0 = fma(rhs_k0, (half8)lhs0.s0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(rhs_k0, (half8)lhs1.s0, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(rhs_k0, (half8)lhs2.s0, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(rhs_k0, (half8)lhs3.s0, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 1: row k + 1 of matrix B times element 1
        const half8 rhs_k1 = vload8(0, (__global half *)(src1_ptr + rhs_offset));
        rhs_offset += src1_stride_y;
        acc0 = fma(rhs_k1, (half8)lhs0.s1, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(rhs_k1, (half8)lhs1.s1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(rhs_k1, (half8)lhs2.s1, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(rhs_k1, (half8)lhs3.s1, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 2: row k + 2 of matrix B times element 2
        const half8 rhs_k2 = vload8(0, (__global half *)(src1_ptr + rhs_offset));
        rhs_offset += src1_stride_y;
        acc0 = fma(rhs_k2, (half8)lhs0.s2, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(rhs_k2, (half8)lhs1.s2, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(rhs_k2, (half8)lhs2.s2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(rhs_k2, (half8)lhs3.s2, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Step 3: row k + 3 of matrix B times element 3
        const half8 rhs_k3 = vload8(0, (__global half *)(src1_ptr + rhs_offset));
        rhs_offset += src1_stride_y;
        acc0 = fma(rhs_k3, (half8)lhs0.s3, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(rhs_k3, (half8)lhs1.s3, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(rhs_k3, (half8)lhs2.s3, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(rhs_k3, (half8)lhs3.s3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Skip the four matrix A columns just consumed
        lhs_offset += 4 * sizeof(half);
    }

    // Left-over loop: remaining columns of matrix A, one per iteration
    for(; k < (int)COLS_A; ++k)
    {
        // One value from every processed row of matrix A
        const half lhs0 = *((__global half *)(src0_ptr + lhs_offset + 0 * src0_stride_y + lhs_plane_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        const half lhs1 = *((__global half *)(src0_ptr + lhs_offset + 1 * src0_stride_y + lhs_plane_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        const half lhs2 = *((__global half *)(src0_ptr + lhs_offset + 2 * src0_stride_y + lhs_plane_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        const half lhs3 = *((__global half *)(src0_ptr + lhs_offset + 3 * src0_stride_y + lhs_plane_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3

        // Eight values from the current row of matrix B
        const half8 rhs = vload8(0, (__global half *)(src1_ptr + rhs_offset));

        // Next element of matrix A, next row of matrix B
        lhs_offset += sizeof(half);
        rhs_offset += src1_stride_y;

        // The scalar matrix A value is broadcast to all eight lanes
        acc0 = fma(rhs, (half8)lhs0, acc0);
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
        acc1 = fma(rhs, (half8)lhs1, acc1);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
        acc2 = fma(rhs, (half8)lhs2, acc2);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
        acc3 = fma(rhs, (half8)lhs3, acc3);
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    }

#if defined(ALPHA)
    // Scale the result by the weight of the matrix-matrix product
    acc0 = acc0 * (half8)ALPHA;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    acc1 = acc1 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    acc2 = acc2 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    acc3 = acc3 * (half8)ALPHA;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#endif // defined(ALPHA)

    // Batch handled by this work-item
    const int batch = get_global_id(2);

    // Address of the first output element written by this work-item
    Image           dst_img  = CONVERT_TO_IMAGE_STRUCT(dst);
    __global uchar *dst_addr = offset(&dst_img, 0, 0);

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // The 2D output tile is written to a 3D tensor whose planes can be separated by padding rows
    // (same layout as sketched for the input above a plane boundary: plane0 | cross_plane_pad | plane1).
    // For each of the (up to) four rows: plane index = row index / HEIGHT_GEMM3D, clamped to the last plane
    uint4 dst_plane_pad = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * NUM_ELEMS_PROCESSED_PER_THREAD_Y)) / (uint4)HEIGHT_GEMM3D;
    dst_plane_pad       = min(DEPTH_GEMM3D - 1, dst_plane_pad);

    // Turn the plane index into the number of padding bytes that precede the row
    dst_plane_pad *= (dst_cross_plane_pad * dst_stride_y);

    // Batched GEMM: the batches live in the fourth dimension, hence the DEPTH_GEMM3D factor
    dst_addr += batch * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Plain 2D output: no padding between rows, so the per-row extra offset is zero
    const uint4 dst_plane_pad = (uint4)(0);

    // Batched GEMM
    dst_addr += batch * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Store the output block, one row of eight values at a time
    vstore8(acc0, 0, (__global half *)(dst_addr + 0 * dst_stride_y + dst_plane_pad.s0));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
    vstore8(acc1, 0, (__global half *)(dst_addr + 1 * dst_stride_y + dst_plane_pad.s1));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
    vstore8(acc2, 0, (__global half *)(dst_addr + 2 * dst_stride_y + dst_plane_pad.s2));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
    vstore8(acc3, 0, (__global half *)(dst_addr + 3 * dst_stride_y + dst_plane_pad.s3));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
}
Vidhya Sudhan Loganathanbdff4912018-05-22 15:03:09 +01003385#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Gian Marco Iodicefd683112018-04-17 09:52:44 +01003386
Gian Marco Iodiceedfa9f42017-08-15 11:45:22 +01003387#endif // defined(COLS_A) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && (NUM_ELEMS_PROCESSED_PER_THREAD_Y)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003388
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003389#if defined(BETA)
/** In-place weighted matrix addition for F32 data: dst = dst + BETA * src.
 *
 * The destination is expected to already hold the (alpha-scaled) A x B product, while the source holds matrix C.
 * Each work-item reads 4 consecutive float values from both matrices and writes 4 values back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F32
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination matrix in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f32(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the per-work-item views of matrix C (src) and of the A x B result (dst)
    Tensor3D c_mat  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D ab_mat = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global float *c_row  = (__global float *)c_mat.ptr;
    __global float *ab_row = (__global float *)ab_mat.ptr;

    // Fetch four values of the A x B result and the matching four values of C
    const float4 ab_vals = vload4(0, ab_row);
    const float4 c_vals  = vload4(0, c_row);

    // Accumulate beta * C onto the A x B result and write it back in place
    vstore4(ab_vals + (float4)BETA * c_vals, 0, ab_row);
}
3430
#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** In-place weighted matrix addition for F16 data: dst = dst + BETA * src.
 *
 * The destination is expected to already hold the (alpha-scaled) A x B product, while the source holds matrix C.
 * Each work-item reads 8 consecutive half values from both matrices and writes 8 values back to the destination.
 *
 * @note The value of beta must be passed at compile time using -DBETA
 *
 * @param[in]      src_ptr                           Pointer to the source matrix (matrix C). Supported data types: F16
 * @param[in]      src_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]      src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      src_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]      src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      src_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]      src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      src_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in, out] dst_ptr                           Pointer to the destination matrix, updated in place. Supported data types: same as @p src_ptr
 * @param[in]      dst_stride_x                      Stride of the destination matrix in X dimension (in bytes)
 * @param[in]      dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      dst_stride_y                      Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]      dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      dst_stride_z                      Stride of the destination matrix in Z dimension (in bytes)
 * @param[in]      dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]      dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
 */
__kernel void gemm_ma_f16(TENSOR3D_DECLARATION(src),
                          TENSOR3D_DECLARATION(dst))
{
    // Resolve the per-work-item views of matrix C (src) and of the A x B result (dst)
    Tensor3D c_mat  = CONVERT_TO_TENSOR3D_STRUCT(src);
    Tensor3D ab_mat = CONVERT_TO_TENSOR3D_STRUCT(dst);

    __global half *c_row  = (__global half *)c_mat.ptr;
    __global half *ab_row = (__global half *)ab_mat.ptr;

    // Fetch eight values of the A x B result and the matching eight values of C
    const half8 ab_vals = vload8(0, ab_row);
    const half8 c_vals  = vload8(0, c_row);

    // Accumulate beta * C onto the A x B result and write it back in place
    vstore8(ab_vals + (half8)BETA * c_vals, 0, ab_row);
}
#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
Anton Lokhmotov3e80c7f2017-11-20 11:02:10 +00003473#endif // defined(BETA)
Anthony Barbier6ff3b192017-09-04 18:44:23 +01003474
#if defined(WIDTH_VECTOR_A)
/** Vector-by-matrix multiplication between each row of A (src0) and matrix B (src1), used for the locally connected layer.
 *
 * Work-item (x, y) multiplies row y of src0 by the Z-slice y of src1 and writes 4 consecutive float results,
 * starting at column 4 * x of that slice.
 *
 * @note The width of A must be passed at compile time using -DWIDTH_VECTOR_A
 *
 * @note The input A and matrix B must not be reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the first source matrix (A). Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the first source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the first source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the first source matrix
 * @param[in]  src1_ptr                           Pointer to the second source matrix (B). Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the second source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the second source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_stride_z                      Stride of the second source matrix in Z dimension (in bytes)
 * @param[in]  src1_step_z                        src1_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the second source matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix. Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 */
__kernel void gemm_lc_vm_f32(IMAGE_DECLARATION(src0),
                             TENSOR3D_DECLARATION(src1),
                             IMAGE_DECLARATION(dst))
{
    // First output column handled by this work-item (4 columns per work-item) and the row of A it consumes
    int col = get_global_id(0) * 4;
    int row = get_global_id(1);

    // Byte offsets into vector A (row of src0) and into matrix B (Z-slice of src1, shifted to the first column)
    int a_off = src0_offset_first_element_in_bytes + src0_stride_y * row;
    int b_off = src1_offset_first_element_in_bytes + src1_stride_z * row;
    b_off += col * sizeof(float);

    // Byte offset one past the last element of the row of A
    int a_end = a_off + (WIDTH_VECTOR_A * sizeof(float));

    float4 sum = 0.0f;

    // Main loop: consume two elements of A (and two rows of B) per iteration
    while(a_off <= (a_end - 2 * (int)sizeof(float)))
    {
        float2 a_pair = vload2(0, (__global float *)(src0_ptr + a_off));
        float4 b_row0 = vload4(0, (__global float *)(src1_ptr + b_off));
        float4 b_row1 = vload4(0, (__global float *)(src1_ptr + b_off + src1_stride_y));

        sum += b_row0 * (float4)a_pair.s0;
        sum += b_row1 * (float4)a_pair.s1;

        a_off += 2 * sizeof(float);
        b_off += 2 * src1_stride_y;
    }

    // Left-over loop: handles the last element of A when WIDTH_VECTOR_A is odd
    while(a_off < a_end)
    {
        float  a_val = *((__global float *)(src0_ptr + a_off));
        float4 b_row = vload4(0, (__global float *)(src1_ptr + b_off));

        sum += b_row * (float4)a_val;

        a_off += sizeof(float);
        b_off += src1_stride_y;
    }

    // Resolve the destination address for this work-item and store the four results
    Image out_img = CONVERT_TO_IMAGE_STRUCT(dst);

    vstore4(sum, 0, (__global float *)(offset(&out_img, 0, 0)));
}
#endif // defined(WIDTH_VECTOR_A)
3542
/** Adds the biases vector to a row of the accumulate tensor, in place.
 *
 * Each work-item loads VECTOR_SIZE elements from the accumulate tensor and from the biases vector,
 * sums them element-wise and writes the result back to the accumulate tensor.
 *
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=short.
 * @note The vector size must be passed at compile time using -DVECTOR_SIZE e.g. -DVECTOR_SIZE=16.
 *
 * @param[in, out] accum_ptr                            Pointer to the accumulate tensor. Supported data type: U8/S8/U16/S16/F16/U32/S32/F32
 * @param[in]      accum_stride_x                       Stride of the accumulate tensor in X dimension (in bytes)
 * @param[in]      accum_step_x                         accum_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      accum_stride_y                       Stride of the accumulate tensor in Y dimension (in bytes)
 * @param[in]      accum_step_y                         accum_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]      accum_offset_first_element_in_bytes  The offset of the first element in the accumulate tensor
 * @param[in]      biases_ptr                           Pointer to the biases vector. Supported data type: same as @p accum_ptr
 * @param[in]      biases_stride_x                      Stride of the biases vector in X dimension (in bytes)
 * @param[in]      biases_step_x                        biases_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]      biases_offset_first_element_in_bytes The offset of the first element in the biases vector
 */
#if defined(DATA_TYPE) && defined(VECTOR_SIZE)
__kernel void gemm_accumulate_biases(
    IMAGE_DECLARATION(accum),
    VECTOR_DECLARATION(biases))
{
    // Resolve the per-work-item views of the accumulate tensor and of the biases vector
    Image  acc_img  = CONVERT_TO_IMAGE_STRUCT(accum);
    Vector bias_vec = CONVERT_TO_VECTOR_STRUCT(biases);

    __global DATA_TYPE *acc_addr  = (__global DATA_TYPE *)acc_img.ptr;
    __global DATA_TYPE *bias_addr = (__global DATA_TYPE *)bias_vec.ptr;

    // Load VECTOR_SIZE elements from each operand
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) acc_vals  = VLOAD(VECTOR_SIZE)(0, acc_addr);
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) bias_vals = VLOAD(VECTOR_SIZE)(0, bias_addr);

    // Element-wise sum, written back to the accumulate buffer
    const VEC_DATA_TYPE(DATA_TYPE, VECTOR_SIZE) summed = bias_vals + acc_vals;
    VSTORE(VECTOR_SIZE)
    (summed, 0, acc_addr);
}
#endif // defined(DATA_TYPE) && defined(VECTOR_SIZE)