blob: a136a1b96bb1e477400d81fb371f81a1bd8e3e7a [file] [log] [blame]
SiCong Li4abc9d12020-10-28 14:19:28 +00001/*
Manuel Bottini99b1a1c2021-03-29 17:15:21 +01002 * Copyright (c) 2020-2021 Arm Limited.
SiCong Li4abc9d12020-10-28 14:19:28 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "gemm_helpers.h"
25#include "repeat.h"
26
Manuel Bottini99b1a1c2021-03-29 17:15:21 +010027#if defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) && defined(IN1_DIM_X)
SiCong Li4abc9d12020-10-28 14:19:28 +000028/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
29 *
SiCong Li0ea50e32020-11-05 09:18:11 +000030 * @note The number of rows of destination matrix must be passed at compile time using -DM
31 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +000032 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
Manuel Bottini99b1a1c2021-03-29 17:15:21 +010033 * @note The number of columns of the reshaped rhs matrix must be passed at compile time using -DIN1_DIM_X
SiCong Li0ea50e32020-11-05 09:18:11 +000034 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
35 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha's value needs to be passed at compile time using -DALPHA
37 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
38 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
39 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
40 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
41 *
42 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
43 * The activation function is performed after the bias addition
44 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
45 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
46 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
47 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
48 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
49 *
50 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
51 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
52 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
53 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
54 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
55 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
56 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
57 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
58 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
59 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
60 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
61 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
62 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
63 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
64 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
65 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
66 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
67 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
68 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
69 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
70 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
71 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
72 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
73 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
74 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
75 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
76 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
77 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
78 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
79 */
__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
{
    // Each work-item computes one 4x4 block of the output.
    // x: index of the transposed 4xH0 block in src1; y: index of the interleaved 4xV0 block in src0; z: batch.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset of this work-item's 4-element sub-row inside the interleaved (A) / transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B (IN1_DIM_X = number of columns of the reshaped rhs matrix)
    __global float *src_end_addr_b = src_addr_b + IN1_DIM_X;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one float4 row of the 4x4 output block each)
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

    // Main loop: two K steps per iteration. In the reshaped layouts consecutive K values
    // are 4*V0 floats apart in A and 4*H0 floats apart in B, hence the 8*V0 / 8*H0 strides.
    for(; src_addr_b <= (src_end_addr_b - (int)(8 * H0)); src_addr_a += 8 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        // Rank-1 update: each scalar of a0 scales the whole b0 row
        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * V0);
        b0 = vload4(0, src_addr_b + 4 * H0);

        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;
    }

    // Leftover loop: one K step at a time for the remaining iterations
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 4 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        c0 += (float4)a0.s0 * b0;
        c1 += (float4)a0.s1 * b0;
        c2 += (float4)a0.s2 * b0;
        c3 += (float4)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row z offsets in bytes; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    // Scales c0..c3 in place (macro token-pastes the row index onto "c")
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a 1D vector broadcast across all output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // Bias is a full matrix: load a 4x4 tile matching this work-item's output block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    // Fused activation applied after the (optional) bias addition
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block; cond_y/cond_x flag the work-items whose tile crosses the M/N edge
    // so the macro performs a partial (PARTIAL_STORE_M0 x PARTIAL_STORE_N0) store instead.
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
    // NOTE(review): "zout.s" is completed inside the macro by token-pasting the row index (zout.s0..zout.s3)
    STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
251
/** This OpenCL kernel is optimized for Bifrost and it computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
253 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000254 * @note The number of rows of destination matrix must be passed at compile time using -DM
255 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000256 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +0000257 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
258 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000259 * @note The optional alpha's value need to be passed at compile time using -DALPHA
260 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
261 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
SiCong Li4abc9d12020-10-28 14:19:28 +0000262 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
263 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
264 *
265 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
266 * The activation function is performed after the bias addition
267 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
268 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
269 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
270 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
271 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
272 *
273 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
274 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
275 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
276 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
277 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
278 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
279 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
280 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
281 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
282 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
283 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
284 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
285 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
286 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
287 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
288 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
289 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
290 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
291 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
292 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
293 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
294 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
295 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
296 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
297 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
298 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
299 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
300 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
301 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
302 */
__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                         )
{
    // Each work-item computes one 4x4 block of the output.
    // x: index of the transposed 4xH0 block in src1; y: index of the interleaved 4xV0 block in src0; z: batch.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset of this work-item's 4-element sub-row inside the interleaved (A) / transposed (B) blocks
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 4;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
    __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators (one float4 row of the 4x4 output block each)
    float4 c0 = 0.0f;
    float4 c1 = 0.0f;
    float4 c2 = 0.0f;
    float4 c3 = 0.0f;

    // Main loop over K, manually unrolled by 4: each pass performs four identical
    // load+16-fma steps. Scalar fma (rather than the vector form used by the Midgard
    // kernel) maps onto Bifrost's fused multiply-add units.
    int i = 0;
    for(; i <= (int)(K - 4); i += 4)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        // Advance one K step: consecutive K values are 4*V0 (A) / 4*H0 (B) floats apart
        src_addr_a += 4 * V0;
        src_addr_b += 4 * H0;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 4 * H0;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 4 * H0;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 4 * H0;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
    }

    // Leftover loop: handles K not being a multiple of 4, one K step per iteration
    for(; i < (int)K; ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = vload4(0, src_addr_a);
        float4 b0 = vload4(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 4 * H0;

        c0.s0 = fma(a0.s0, b0.s0, c0.s0);
        c0.s1 = fma(a0.s0, b0.s1, c0.s1);
        c0.s2 = fma(a0.s0, b0.s2, c0.s2);
        c0.s3 = fma(a0.s0, b0.s3, c0.s3);

        c1.s0 = fma(a0.s1, b0.s0, c1.s0);
        c1.s1 = fma(a0.s1, b0.s1, c1.s1);
        c1.s2 = fma(a0.s1, b0.s2, c1.s2);
        c1.s3 = fma(a0.s1, b0.s3, c1.s3);

        c2.s0 = fma(a0.s2, b0.s0, c2.s0);
        c2.s1 = fma(a0.s2, b0.s1, c2.s1);
        c2.s2 = fma(a0.s2, b0.s2, c2.s2);
        c2.s3 = fma(a0.s2, b0.s3, c2.s3);

        c3.s0 = fma(a0.s3, b0.s0, c3.s0);
        c3.s1 = fma(a0.s3, b0.s1, c3.s1);
        c3.s2 = fma(a0.s3, b0.s2, c3.s2);
        c3.s3 = fma(a0.s3, b0.s3, c3.s3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row z offsets in bytes; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    // Scales c0..c3 in place (macro token-pastes the row index onto "c")
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a 1D vector broadcast across all output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));

    LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // Bias is a full matrix: load a 4x4 tile matching this work-item's output block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    // Fused activation applied after the (optional) bias addition
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x4 block; cond_y/cond_x flag the work-items whose tile crosses the M/N edge
    // so the macro performs a partial (PARTIAL_STORE_M0 x PARTIAL_STORE_N0) store instead.
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
    // NOTE(review): "zout.s" is completed inside the macro by token-pasting the row index (zout.s0..zout.s3)
    STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
580
581#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
582/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
583 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000584 * @note The number of rows of destination matrix must be passed at compile time using -DM
585 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000586 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
Manuel Bottini99b1a1c2021-03-29 17:15:21 +0100587 * @note The number of columns of the reshaped rhs matrix must be passed at compile time using -DIN1_DIM_X
SiCong Li0ea50e32020-11-05 09:18:11 +0000588 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
589 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000590 * @note The optional alpha's value need to be passed at compile time using -DALPHA
591 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
592 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
593 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
594 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
595 *
596 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
597 * The activation function is performed after the bias addition
598 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
599 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
600 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
601 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
602 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
603 *
604 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
605 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
606 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
607 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
608 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
609 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
610 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
611 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
612 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
613 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
614 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
615 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
616 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
617 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
618 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
619 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
620 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
621 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
622 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
623 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
624 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
625 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
626 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
627 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
628 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
629 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
630 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
631 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
632 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
633 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif //defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                 )
{
    // Tile coordinates in the reshaped inputs: each work-item computes one 4x8 block of the output.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset within the current interleaved (4 rows x V0) / transposed (8 cols x H0) block
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B. IN1_DIM_X is the number of columns of the reshaped rhs matrix.
    __global half *src_end_addr_b = src_addr_b + IN1_DIM_X;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: 4 rows x 8 columns of the output tile
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: unrolled by 2, processes two K-steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a + 4 * V0);
        b0 = vload8(0, src_addr_b + 8 * H0);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Leftover loop: one K-step per iteration (handles an odd number of remaining steps)
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row Z offsets (in bytes) for the 3D-reinterpreted output; stays 0 for plain 2D output
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a 1D vector broadcast over all output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)

    // Bias is a full matrix: offset by output tile position (x, y, z)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block, guarding against partial tiles at the right/bottom edges of the output
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
806
807/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32 floating point variable.
808 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000809 * @note The number of rows of destination matrix must be passed at compile time using -DM
810 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000811 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
Manuel Bottini99b1a1c2021-03-29 17:15:21 +0100812 * @note The number of columns of the reshaped rhs matrix must be passed at compile time using -DIN1_DIM_X
SiCong Li0ea50e32020-11-05 09:18:11 +0000813 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
814 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000815 * @note The optional alpha's value need to be passed at compile time using -DALPHA
816 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
817 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
818 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
819 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
820 *
821 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
822 * The activation function is performed after the bias addition
823 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
824 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
825 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
826 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
827 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
828 *
829 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
830 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
831 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
832 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
833 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
834 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
835 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
836 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
837 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
838 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
839 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
840 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
841 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
842 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
843 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
844 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
845 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
846 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
847 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
848 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
849 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
850 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
851 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
852 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
853 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
854 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
855 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
856 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
857 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
858 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif //defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                       )
{
    // Tile coordinates in the reshaped inputs: each work-item computes one 4x8 block of the output.
    // F16 inputs are widened to F32 for accumulation; the result is narrowed back to F16 on store.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset within the current interleaved (4 rows x V0) / transposed (8 cols x H0) block
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // Compute end row address for matrix B. IN1_DIM_X is the number of columns of the reshaped rhs matrix.
    __global half *src_end_addr_b = src_addr_b + IN1_DIM_X;

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: 4 rows x 8 columns, accumulated in 32-bit floating point
    float8 c0 = 0.0f;
    float8 c1 = 0.0f;
    float8 c2 = 0.0f;
    float8 c3 = 0.0f;

    // Main loop: unrolled by 2, processes two K-steps per iteration
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed), widening to F32
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = convert_float4(vload4(0, src_addr_a + 4 * V0));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * H0));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Leftover loop: one K-step per iteration (handles an odd number of remaining steps)
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row Z offsets (in bytes) for the 3D-reinterpreted output; stays 0 for plain 2D output
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias. The bias is stored as F16 but scaled and added in F32 to match the accumulators.
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a 1D vector broadcast over all output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else // defined(BROADCAST_BIAS)
    // Bias is a full matrix: offset by output tile position (x, y, z)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Narrow the F32 accumulators back to F16 before activation and store
    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block, guarding against partial tiles at the right/bottom edges of the output
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
1041
1042/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
1043 *
SiCong Li0ea50e32020-11-05 09:18:11 +00001044 * @note The number of rows of destination matrix must be passed at compile time using -DM
1045 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +00001046 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +00001047 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1048 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +00001049 * @note The optional alpha's value need to be passed at compile time using -DALPHA
1050 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
1051 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
1052 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1053 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
1054 *
1055 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1056 * The activation function is performed after the bias addition
1057 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
1058 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1059 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1060 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1061 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1062 *
1063 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1064 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1065 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1066 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1067 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1068 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1069 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1070 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1071 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1072 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1073 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1074 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1075 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1076 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1077 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
1078 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1079 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
1080 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1081 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1082 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1083 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1084 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1085 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1086 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1087 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1088 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  cross_plane_pad                    (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1091 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif //defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                         )
{
    // Tile coordinates in the reshaped inputs: each work-item computes one 4x8 block of the output.
    // Bifrost variant: loops are driven by K (not by an end pointer) and use fma for the accumulation.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Offset within the current interleaved (4 rows x V0) / transposed (8 cols x H0) block
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // src_addr_a = address of matrix A
    // src_addr_b = address of matrix B
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Reset accumulators: 4 rows x 8 columns of the output tile
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: 4 K-steps per iteration
    int i = 0;
    for(; i <= (int)(K - 4); i += 4)
    {
#if V0 == 1
        // With V0 == 1 consecutive rows of A are contiguous, so two K-steps of A
        // (2 x 4 values) can be fetched with a single vload8.
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);
#else // V0 == 1
        // General case (V0 > 1): one vload4 of A per K-step, 4 steps unrolled
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
#endif // V0 == 1
    }

    // Leftover loop: remaining K-steps when K is not a multiple of 4
    for(; i < (int)K; ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row Z offsets (in bytes) for the 3D-reinterpreted output; stays 0 for plain 2D output
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a 1D vector broadcast over all output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // Bias is a full matrix: offset by output tile position (x, y, z)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
                                    2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block, guarding against partial tiles at the right/bottom edges of the output
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
1340
1341#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
1342
Manuel Bottini99b1a1c2021-03-29 17:15:21 +01001343#endif // defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0) && defined(IN1_DIM_X)
SiCong Li4abc9d12020-10-28 14:19:28 +00001344
SiCong Li0ea50e32020-11-05 09:18:11 +00001345#if defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
SiCong Li4abc9d12020-10-28 14:19:28 +00001346#if defined(DATA_TYPE)
1347#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, N0)
1348/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
1349 *
1350 * @note This OpenCL kernel works with floating point data types (F16/F32)
1351 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
1352 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0
SiCong Li0ea50e32020-11-05 09:18:11 +00001353 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
1354 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1355 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
1356 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00001357 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1358 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
1359 *
1360 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1361 * The activation function is performed after the bias addition
1362 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1363 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1364 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1365 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1366 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1367 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1368 *
1369 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16/F32
1370 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1371 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1372 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1373 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1374 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1375 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1376 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1377 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1378 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1379 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1380 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1381 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1382 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1383 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
1384 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1385 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
1386 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1387 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1388 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1389 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1390 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1391 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1392 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1393 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1394 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1395 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1396 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1397 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1398 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
1399 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif // defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                     )
{
    // First destination column handled by this work-item (N0 columns per work-item along x)
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B
    // (.s0 tracks the byte offset into src0, .s1 the byte offset into src1)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    // COMPUTE_M0_START_ROW places the partially-filled M0 block (when M % M0 != 0) at the first
    // work-item along y, so all remaining work-items operate on full M0-row blocks
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One-past-the-end byte offset of the current row of matrix A (K accumulations in total)
    int end_row_vec_a = src_addr.s0 + (K * sizeof(DATA_TYPE));

    // Accumulators: one N0-wide vector per processed row of A (M0 rows at most)
    VECTOR_TYPE acc0 = 0.0f;
#if M0 > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // M0 > 1
#if M0 > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // M0 > 2
#if M0 > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // M0 > 3

    // Main loop along K, unrolled by 2: each iteration consumes two elements of each A row
    // and two rows of B (hence the 2 * src1_stride_y increment on src_addr.s1)
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if M0 > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // M0 > 1
#if M0 > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // M0 > 2
#if M0 > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // M0 > 3
    }

    // Leftover loop: handles the final K element when K is odd (one A element / one B row per iteration)
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if M0 > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // M0 > 1
#if M0 > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // M0 > 2
#if M0 > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                               PARTIAL_STORE_M0)
                               * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Bias is a single row broadcast over all M0 output rows
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // Per-element bias: same M0xN0 tile position as the destination, including batch (z) offset
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                PARTIAL_STORE_M0)
                                * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block.
    // Because COMPUTE_M0_START_ROW assigns the partial M0 block to the first work-item along y,
    // only gid(1) == 0 may need a partial store in y; in x the last work-item covers the tail.
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
1664#endif // defined(DATA_TYPE)
1665
1666/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1667 *
1668 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1669 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
SiCong Li0ea50e32020-11-05 09:18:11 +00001670 * @note This kernel processed a fixed number of elements along x: -DN0=4.
1671 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
1672 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1673 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
1674 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00001675 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1676 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
1677 *
1678 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1679 * The activation function is performed after the bias addition
1680 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1681 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1682 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1683 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1684 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1685 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1686 *
1687 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1688 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1689 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1690 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1691 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1692 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1693 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1694 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1695 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1696 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1697 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1698 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
1699 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
1700 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1701 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
1702 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1703 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
1704 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1705 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1706 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1707 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1708 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1709 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1710 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1711 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1712 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1713 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1714 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1715 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1716 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1717 */
1718__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1719 IMAGE_DECLARATION(src1),
1720#if defined(BETA)
1721 IMAGE_DECLARATION(src2),
1722#endif // defined(BETA)
1723 IMAGE_DECLARATION(dst),
1724 uint src0_stride_z,
1725 uint src1_stride_z,
1726#if defined(BETA)
1727 uint src2_stride_z,
1728#endif //defined(BETA)
1729 uint dst_stride_z
1730#if defined(REINTERPRET_INPUT_AS_3D)
1731 ,
1732 uint src_cross_plane_pad
1733#endif // REINTERPRET_INPUT_AS_3D
1734#if defined(REINTERPRET_OUTPUT_AS_3D)
1735 ,
1736 uint dst_cross_plane_pad
1737#endif // REINTERPRET_OUTPUT_AS_3D
1738 )
1739{
1740 int idx = get_global_id(0) * N0;
1741
1742 // Compute starting address for matrix A and matrix B
1743 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1744
1745 // Update address for matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00001746 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00001747
1748 // Update address for matrix B
1749 src_addr.s1 += idx * sizeof(float);
1750
1751#if defined(REINTERPRET_INPUT_AS_3D)
1752 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1753 // in order to take into account the presence of possible cross plane paddings
1754 //
1755 // | |
1756 // | plane0 |
1757 // | |
1758 // |__________________|
1759 // |******************|
1760 // | cross_plane_pad |
1761 // |******************|
1762 // | |
1763 // | plane1 |
1764 // | |
1765 // |__________________|
1766
SiCong Li0ea50e32020-11-05 09:18:11 +00001767 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
1768 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00001769 zin = min(DEPTH_GEMM3D - 1, zin);
1770
1771 // Add offset due to the cross plane paddings
1772 zin *= (src_cross_plane_pad * src0_stride_y);
1773
1774 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1775 // multiply src0_stride_z by DEPTH_GEMM3D
1776 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1777
1778#else // defined(REINTERPRET_INPUT_AS_3D)
1779
1780 // Add offset for batched GEMM
1781 src_addr.s0 += get_global_id(2) * src0_stride_z;
1782
1783#endif // defined(REINTERPRET_INPUT_AS_3D)
1784
1785#if defined(MATRIX_B_DEPTH)
1786 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1787 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1788#else // defined(MATRIX_B_DEPTH)
1789 src_addr.s1 += get_global_id(2) * src1_stride_z;
1790#endif // defined(MATRIX_B_DEPTH)
1791
1792 // Initialize accumulators
1793 float4 acc0 = 0.0f;
1794
1795#if M0 > 1
1796 float4 acc1 = 0.0f;
1797#endif // M0 > 1
1798
1799#if M0 > 2
1800 float4 acc2 = 0.0f;
1801#endif // M0 > 2
1802
1803#if M0 > 3
1804 float4 acc3 = 0.0f;
1805#endif // M0 > 3
1806
1807 // A and B src indices get incremented at the same time.
1808 int i = 0;
1809 for(; i <= ((int)K - 4); i += 4)
1810 {
1811#if defined(REINTERPRET_INPUT_AS_3D)
1812 // Load values from matrix A and matrix B
1813 LOAD_BLOCK(M0, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
1814#else // defined(REINTERPRET_INPUT_AS_3D)
1815 // Load values from matrix A and matrix B
1816 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1817#if M0 > 1
1818 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1819#endif // M0 > 1
1820#if M0 > 2
1821 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1822#endif // M0 > 2
1823#if M0 > 3
1824 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1825#endif // M0 > 3
1826#endif // defined(REINTERPRET_INPUT_AS_3D)
1827
1828 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1829 src_addr.s1 += src1_stride_y;
1830
1831 // Multiply and accumulate
1832 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
1833 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
1834 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
1835 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
1836
1837#if M0 > 1
1838
1839 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
1840 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
1841 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
1842 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
1843
1844#endif // M0 > 1
1845#if M0 > 2
1846
1847 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
1848 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
1849 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
1850 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
1851
1852#endif // M0 > 2
1853#if M0 > 3
1854
1855 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
1856 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
1857 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
1858 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
1859#endif // M0 > 3
1860
1861 // Load values from matrix A and matrix B
1862 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1863 src_addr.s1 += src1_stride_y;
1864
1865 // Multiply and accumulate
1866 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
1867 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
1868 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
1869 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
1870
1871#if M0 > 1
1872
1873 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
1874 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
1875 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
1876 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
1877
1878#endif // M0 > 1
1879#if M0 > 2
1880
1881 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
1882 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
1883 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
1884 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
1885
1886#endif // M0 > 2
1887#if M0 > 3
1888
1889 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
1890 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
1891 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
1892 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
1893#endif // M0 > 3
1894
1895 // Load values from matrix A and matrix B
1896 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1897 src_addr.s1 += src1_stride_y;
1898
1899 // Multiply and accumulate
1900 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
1901 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
1902 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
1903 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
1904
1905#if M0 > 1
1906
1907 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
1908 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
1909 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
1910 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
1911
1912#endif // M0 > 1
1913#if M0 > 2
1914
1915 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
1916 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
1917 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
1918 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
1919
1920#endif // M0 > 2
1921#if M0 > 3
1922
1923 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
1924 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
1925 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
1926 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
1927#endif // M0 > 3
1928
1929 // Load values from matrix A and matrix B
1930 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1931 src_addr.s1 += src1_stride_y;
1932
1933 // Multiply and accumulate
1934 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
1935 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
1936 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
1937 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
1938
1939#if M0 > 1
1940
1941 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
1942 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
1943 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
1944 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
1945
1946#endif // M0 > 1
1947#if M0 > 2
1948
1949 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
1950 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
1951 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
1952 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
1953
1954#endif // M0 > 2
1955#if M0 > 3
1956
1957 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
1958 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
1959 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
1960 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
1961#endif // M0 > 3
1962
1963 src_addr.s0 += 4 * sizeof(float);
1964 }
1965
1966 for(; i < (int)K; ++i)
1967 {
1968#if defined(REINTERPRET_INPUT_AS_3D)
1969 // Load values from matrix A
1970 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1971#if M0 > 1
1972 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1973#endif // M0 > 1
1974#if M0 > 2
1975 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1976#endif // M0 > 2
1977#if M0 > 3
1978 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1979#endif // M0 > 3
1980#else // defined(REINTERPRET_INPUT_AS_3D)
1981 // Load values from matrix A
1982 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1983#if M0 > 1
1984 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1985#endif // M0 > 1
1986#if M0 > 2
1987 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1988#endif // M0 > 2
1989#if M0 > 3
1990 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1991#endif // M0 > 3
1992#endif // defined(REINTERPRET_INPUT_AS_3D)
1993
1994 // Load values from matrix B
1995 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1996 src_addr.s1 += src1_stride_y;
1997
1998 // Multiply and accumulate
1999 acc0.s0 = fma(a0, b0.s0, acc0.s0);
2000 acc0.s1 = fma(a0, b0.s1, acc0.s1);
2001 acc0.s2 = fma(a0, b0.s2, acc0.s2);
2002 acc0.s3 = fma(a0, b0.s3, acc0.s3);
2003#if M0 > 1
2004 acc1.s0 = fma(a1, b0.s0, acc1.s0);
2005 acc1.s1 = fma(a1, b0.s1, acc1.s1);
2006 acc1.s2 = fma(a1, b0.s2, acc1.s2);
2007 acc1.s3 = fma(a1, b0.s3, acc1.s3);
2008#endif // M0 > 1
2009#if M0 > 2
2010 acc2.s0 = fma(a2, b0.s0, acc2.s0);
2011 acc2.s1 = fma(a2, b0.s1, acc2.s1);
2012 acc2.s2 = fma(a2, b0.s2, acc2.s2);
2013 acc2.s3 = fma(a2, b0.s3, acc2.s3);
2014#endif // M0 > 2
2015#if M0 > 3
2016 acc3.s0 = fma(a3, b0.s0, acc3.s0);
2017 acc3.s1 = fma(a3, b0.s1, acc3.s1);
2018 acc3.s2 = fma(a3, b0.s2, acc3.s2);
2019 acc3.s3 = fma(a3, b0.s3, acc3.s3);
2020#endif // M0 > 3
2021
2022 src_addr.s0 += sizeof(float);
2023 }
2024
2025 int z = get_global_id(2);
2026
SiCong Li4abc9d12020-10-28 14:19:28 +00002027 // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00002028 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
Manuel Bottini99b1a1c2021-03-29 17:15:21 +01002029 PARTIAL_STORE_M0)
2030 * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00002031
2032 uint4 zout = 0;
2033
2034#if defined(REINTERPRET_OUTPUT_AS_3D)
2035 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2036 // in order to take into account the presence of possible cross plane paddings
2037 //
2038 // | |
2039 // | plane0 |
2040 // | |
2041 // |__________________|
2042 // |******************|
2043 // | cross_plane_pad |
2044 // |******************|
2045 // | |
2046 // | plane1 |
2047 // | |
2048 // |__________________|
2049
SiCong Li0ea50e32020-11-05 09:18:11 +00002050 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
2051 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002052 zout = min(DEPTH_GEMM3D - 1, zout);
2053
2054 // Add offset due to the cross plane paddings
2055 zout *= (dst_cross_plane_pad * dst_stride_y);
2056
2057 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2058 // multiply dst_stride_z by DEPTH_GEMM3D
2059 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2060#else // defined(REINTERPRET_OUTPUT_AS_3D)
2061 // Add offset for batched GEMM
2062 dst_addr += z * dst_stride_z;
2063#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2064
2065 // Multiply by the weight of matrix-matrix product and store the result
2066#if defined(ALPHA)
2067 SCALE_BLOCK(M0, float, acc, ALPHA);
2068#endif // defined(ALPHA)
2069
2070 // Add beta*bias
2071#if defined(BETA)
2072 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2073
2074#if defined(BROADCAST_BIAS)
2075 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
2076
2077 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2078
2079#ifndef UNIT_BETA
2080 SCALE_BLOCK(1, float, bias, BETA);
2081#endif // UNIT_BIAS
2082
2083 // acc = acc + bias[broadcasted]
2084 ADD_BLOCK_BROADCAST(M0, acc, bias0);
2085
2086#else // defined(BROADCAST_BIAS)
SiCong Li0ea50e32020-11-05 09:18:11 +00002087 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2088 PARTIAL_STORE_M0)
2089 * src2_stride_y)
2090 + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00002091
2092 LOAD_BLOCK(M0, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2093
2094#ifndef UNIT_BETA
2095 SCALE_BLOCK(M0, float, bias, BETA);
2096#endif // UNIT_BIAS
2097
2098 // acc = acc + bias
2099 ADD_BLOCK(M0, acc, bias);
2100
2101#endif // defined(BROADCAST_BIAS)
2102#endif // defined(BETA)
2103
2104#if defined(ACTIVATION_TYPE)
2105 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
2106#endif // defined(ACTIVATION_TYPE)
2107
2108 // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00002109 const bool cond_y = get_global_id(1) == 0;
2110 const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
2111 STORE_BLOCK_BOUNDARY_AWARE(M0, 4, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00002112}
2113
2114/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2115 *
2116 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
2117 * This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
2118 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=2.
2120 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
2121 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2122 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2123 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00002124 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2125 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
2126 *
2127 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2128 * The activation function is performed after the bias addition
2129 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2130 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2131 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2132 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2133 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2134 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2135 *
2136 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
2137 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2138 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2139 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2140 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2141 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2142 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2143 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2144 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2145 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2146 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2147 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2148 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2149 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2150 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2151 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2152 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2153 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2154 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2155 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2156 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2157 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2158 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2159 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2160 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2161 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2162 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2163 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2164 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2165 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2166 */
2167__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
2168 IMAGE_DECLARATION(src1),
2169#if defined(BETA)
2170 IMAGE_DECLARATION(src2),
2171#endif // defined(BETA)
2172 IMAGE_DECLARATION(dst),
2173 uint src0_stride_z,
2174 uint src1_stride_z,
2175#if defined(BETA)
2176 uint src2_stride_z,
2177#endif //defined(BETA)
2178 uint dst_stride_z
2179#if defined(REINTERPRET_INPUT_AS_3D)
2180 ,
2181 uint src_cross_plane_pad
2182#endif // REINTERPRET_INPUT_AS_3D
2183#if defined(REINTERPRET_OUTPUT_AS_3D)
2184 ,
2185 uint dst_cross_plane_pad
2186#endif // REINTERPRET_OUTPUT_AS_3D
2187 )
2188{
2189 // Requires 2 N0, C vect2, A vect4, B (2 vload2) // to fix for M0 > 1
2190 int idx = get_global_id(0) * N0;
2191
2192 // Compute starting address for matrix A and Matrix B
2193 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2194
2195 // Update address for the matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00002196 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00002197
2198 // Update address for the matrix B
2199 src_addr.s1 += idx * sizeof(float);
2200
2201#if defined(REINTERPRET_INPUT_AS_3D)
2202 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2203 // in order to take into account the presence of possible cross plane paddings
2204 //
2205 // | |
2206 // | plane0 |
2207 // | |
2208 // |__________________|
2209 // |******************|
2210 // | cross_plane_pad |
2211 // |******************|
2212 // | |
2213 // | plane1 |
2214 // | |
2215 // |__________________|
2216
SiCong Li0ea50e32020-11-05 09:18:11 +00002217 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
2218 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002219 zin = min(DEPTH_GEMM3D - 1, zin);
2220
2221 // Add offset due to the cross plane paddings
2222 zin *= (src_cross_plane_pad * src0_stride_y);
2223
2224 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2225 // multiply src0_stride_z by DEPTH_GEMM3D
2226 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2227
2228#else // defined(REINTERPRET_INPUT_AS_3D)
2229
2230 // Add offset for batched GEMM
2231 src_addr.s0 += get_global_id(2) * src0_stride_z;
2232
2233#endif // defined(REINTERPRET_INPUT_AS_3D)
2234
2235#if defined(MATRIX_B_DEPTH)
2236 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2237 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2238#else // defined(MATRIX_B_DEPTH)
2239 src_addr.s1 += get_global_id(2) * src1_stride_z;
2240#endif // defined(MATRIX_B_DEPTH)
2241
2242 // Initialize accumulators
2243 float2 acc0 = 0.0f;
2244#if M0 > 1
2245 float2 acc1 = 0.0f;
2246#endif // M0 > 1
2247#if M0 > 2
2248 float2 acc2 = 0.0f;
2249#endif // M0 > 2
2250#if M0 > 3
2251 float2 acc3 = 0.0f;
2252#endif // M0 > 3
2253
2254 // A and B src indices get incremented at the same time.
2255 int i = 0;
2256 for(; i <= ((int)K - 8); i += 8)
2257 {
2258#if defined(REINTERPRET_INPUT_AS_3D)
2259 // Load values from matrix A
2260 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
2261#else // defined(REINTERPRET_INPUT_AS_3D)
2262 // Load values from matrix A
2263 float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
2264#endif // defined(REINTERPRET_INPUT_AS_3D)
2265
2266 // Load values from matrix B
2267 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2268 src_addr.s1 += src1_stride_y;
2269 float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2270 src_addr.s1 += src1_stride_y;
2271 float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2272 src_addr.s1 += src1_stride_y;
2273 float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2274 src_addr.s1 += src1_stride_y;
2275 float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2276 src_addr.s1 += src1_stride_y;
2277 float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2278 src_addr.s1 += src1_stride_y;
2279 float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2280 src_addr.s1 += src1_stride_y;
2281 float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2282 src_addr.s1 += src1_stride_y;
2283
2284 // Multiply and accumulate
2285 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
2286 acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
2287 acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
2288 acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
2289 acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
2290 acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
2291 acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
2292 acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);
2293
2294 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
2295 acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
2296 acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
2297 acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
2298 acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
2299 acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
2300 acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
2301 acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);
2302
2303#if M0 > 1
2304#if defined(REINTERPRET_INPUT_AS_3D)
2305 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2306#else // defined(REINTERPRET_INPUT_AS_3D)
2307 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2308#endif // defined(REINTERPRET_INPUT_AS_3D)
2309 acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
2310 acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
2311 acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
2312 acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
2313 acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
2314 acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
2315 acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
2316 acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);
2317
2318 acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
2319 acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
2320 acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
2321 acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
2322 acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
2323 acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
2324 acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
2325 acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
2326#endif // M0 > 1
2327#if M0 > 2
2328#if defined(REINTERPRET_INPUT_AS_3D)
2329 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2330#else // defined(REINTERPRET_INPUT_AS_3D)
2331 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2332#endif // defined(REINTERPRET_INPUT_AS_3D)
2333 acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
2334 acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
2335 acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
2336 acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
2337 acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
2338 acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
2339 acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
2340 acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);
2341
2342 acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
2343 acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
2344 acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
2345 acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
2346 acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
2347 acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
2348 acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
2349 acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
2350#endif // M0 > 2
2351#if M0 > 3
2352#if defined(REINTERPRET_INPUT_AS_3D)
2353 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2354#else // defined(REINTERPRET_INPUT_AS_3D)
2355 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2356#endif // defined(REINTERPRET_INPUT_AS_3D)
2357 acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
2358 acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
2359 acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
2360 acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
2361 acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
2362 acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
2363 acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
2364 acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);
2365
2366 acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
2367 acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
2368 acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
2369 acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
2370 acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
2371 acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
2372 acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
2373 acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
2374#endif // M0 > 3
2375
2376 src_addr.s0 += sizeof(float) * 8;
2377 }
2378 // float size increment
2379 for(; i < (int)K; ++i)
2380 {
2381#if defined(REINTERPRET_INPUT_AS_3D)
2382 // Load values from matrix A
2383 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2384#if M0 > 1
2385 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2386#endif // M0 > 1
2387#if M0 > 2
2388 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2389#endif // M0 > 2
2390#if M0 > 3
2391 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2392#endif // M0 > 3
2393#else // defined(REINTERPRET_INPUT_AS_3D)
2394 // Load values from matrix A
2395 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2396#if M0 > 1
2397 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2398#endif // M0 > 1
2399#if M0 > 2
2400 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2401#endif // M0 > 2
2402#if M0 > 3
2403 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2404#endif // M0 > 3
2405#endif // defined(REINTERPRET_INPUT_AS_3D)
2406
2407 // Load values from matrix B
2408 float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
2409 src_addr.s1 += src1_stride_y;
2410
2411 // Multiply and accumulate
2412 acc0.s0 = fma(a0, b0.s0, acc0.s0);
2413 acc0.s1 = fma(a0, b0.s1, acc0.s1);
2414#if M0 > 1
2415 acc1.s0 = fma(a1, b0.s0, acc1.s0);
2416 acc1.s1 = fma(a1, b0.s1, acc1.s1);
2417#endif // M0 > 1
2418#if M0 > 2
2419 acc2.s0 = fma(a2, b0.s0, acc2.s0);
2420 acc2.s1 = fma(a2, b0.s1, acc2.s1);
2421#endif // M0 > 2
2422#if M0 > 3
2423 acc3.s0 = fma(a3, b0.s0, acc3.s0);
2424 acc3.s1 = fma(a3, b0.s1, acc3.s1);
2425#endif // M0 > 3
2426
2427 src_addr.s0 += sizeof(float);
2428 }
2429
2430 int z = get_global_id(2);
2431
SiCong Li4abc9d12020-10-28 14:19:28 +00002432 // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00002433 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
Manuel Bottini99b1a1c2021-03-29 17:15:21 +01002434 PARTIAL_STORE_M0)
2435 * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00002436
2437 uint4 zout = 0;
2438
2439#if defined(REINTERPRET_OUTPUT_AS_3D)
2440
2441 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2442 // in order to take into account the presence of possible cross plane paddings
2443 //
2444 // | |
2445 // | plane0 |
2446 // | |
2447 // |__________________|
2448 // |******************|
2449 // | cross_plane_pad |
2450 // |******************|
2451 // | |
2452 // | plane1 |
2453 // | |
2454 // |__________________|
2455
SiCong Li0ea50e32020-11-05 09:18:11 +00002456 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
2457 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002458 zout = min(DEPTH_GEMM3D - 1, zout);
2459
2460 // Add offset due to the cross plane paddings
2461 zout *= (dst_cross_plane_pad * dst_stride_y);
2462
2463 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2464 // multiply dst_stride_z by DEPTH_GEMM3D
2465 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2466#else // defined(REINTERPRET_OUTPUT_AS_3D)
2467 // Add offset for batched GEMM
2468 dst_addr += z * dst_stride_z;
2469#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2470
2471 // Multiply by the weight of matrix-matrix product and store the result
2472#if defined(ALPHA)
2473 SCALE_BLOCK(M0, float, acc, ALPHA);
2474#endif // defined(ALPHA)
2475
2476 // Add beta*bias
2477#if defined(BETA)
2478 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2479
2480#if defined(BROADCAST_BIAS)
2481 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));
2482
2483 LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
2484
2485#ifndef UNIT_BETA
2486 SCALE_BLOCK(1, float, bias, BETA);
2487#endif // UNIT_BIAS
2488
2489 // acc = acc + bias[broadcasted]
2490 ADD_BLOCK_BROADCAST(M0, acc, bias0);
2491
2492#else // defined(BROADCAST_BIAS)
SiCong Li0ea50e32020-11-05 09:18:11 +00002493 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2494 PARTIAL_STORE_M0)
2495 * src2_stride_y)
2496 + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00002497
2498 LOAD_BLOCK(M0, 2, float, bias, src2_addr, 0, src2_stride_y, zero);
2499
2500#ifndef UNIT_BETA
2501 SCALE_BLOCK(M0, float, bias, BETA);
2502#endif // UNIT_BIAS
2503
2504 // acc = acc + bias
2505 ADD_BLOCK(M0, acc, bias);
2506
2507#endif // defined(BROADCAST_BIAS)
2508#endif // defined(BETA)
2509
2510#if defined(ACTIVATION_TYPE)
2511 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
2512#endif // defined(ACTIVATION_TYPE)
2513
2514 // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00002515 const bool cond_y = get_global_id(1) == 0;
2516 const bool cond_x = ((get_global_id(0) + 1) * 2 >= N);
2517 STORE_BLOCK_BOUNDARY_AWARE(M0, 2, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00002518}
2519
2520#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2522 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulates the result in a 32-bit floating point variable.
2524 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
2526 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
2527 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2528 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2529 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00002530 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2531 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
2532 *
2533 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2534 * The activation function is performed after the bias addition
2535 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2536 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2537 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2538 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2539 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2540 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2541 *
2542 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2543 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2544 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2545 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2546 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2547 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2548 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2549 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2550 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2551 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2552 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2553 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2554 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2555 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2556 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2557 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2558 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2559 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2560 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2561 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2562 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2563 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2564 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2565 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2566 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2567 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2568 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2569 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2570 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2571 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2572 */
2573__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
2574 IMAGE_DECLARATION(src1),
2575#if defined(BETA)
2576 IMAGE_DECLARATION(src2),
2577#endif // defined(BETA)
2578 IMAGE_DECLARATION(dst),
2579 uint src0_stride_z,
2580 uint src1_stride_z,
2581#if defined(BETA)
2582 uint src2_stride_z,
2583#endif //defined(BETA)
2584 uint dst_stride_z
2585#if defined(REINTERPRET_INPUT_AS_3D)
2586 ,
2587 uint src_cross_plane_pad
2588#endif // REINTERPRET_INPUT_AS_3D
2589#if defined(REINTERPRET_OUTPUT_AS_3D)
2590 ,
2591 uint dst_cross_plane_pad
2592#endif // REINTERPRET_OUTPUT_AS_3D
2593 )
2594{
2595 int idx = get_global_id(0) * N0;
2596
2597 // Compute starting address for matrix A and Matrix B
2598 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2599
2600 // Update address for the matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00002601 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00002602
2603 // Update address for the matrix B
2604 src_addr.s1 += idx * sizeof(half);
2605
2606#if defined(REINTERPRET_INPUT_AS_3D)
2607 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2608 // in order to take into account the presence of possible cross plane paddings
2609 //
2610 // | |
2611 // | plane0 |
2612 // | |
2613 // |__________________|
2614 // |******************|
2615 // | cross_plane_pad |
2616 // |******************|
2617 // | |
2618 // | plane1 |
2619 // | |
2620 // |__________________|
2621
SiCong Li0ea50e32020-11-05 09:18:11 +00002622 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
2623 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002624 zin = min(DEPTH_GEMM3D - 1, zin);
2625
2626 // Add offset due to the cross plane paddings
2627 zin *= (src_cross_plane_pad * src0_stride_y);
2628
2629 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2630 // multiply src0_stride_z by DEPTH_GEMM3D
2631 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2632
2633#else // defined(REINTERPRET_INPUT_AS_3D)
2634
2635 // Add offset for batched GEMM
2636 src_addr.s0 += get_global_id(2) * src0_stride_z;
2637
2638#endif // defined(REINTERPRET_INPUT_AS_3D)
2639
2640#if defined(MATRIX_B_DEPTH)
2641 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2642 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2643#else // defined(MATRIX_B_DEPTH)
2644 src_addr.s1 += get_global_id(2) * src1_stride_z;
2645#endif // defined(MATRIX_B_DEPTH)
2646
2647 float8 acc0 = 0.0h;
2648#if M0 > 1
2649 float8 acc1 = 0.0h;
2650#endif // M0 > 1
2651#if M0 > 2
2652 float8 acc2 = 0.0h;
2653#endif // M0 > 2
2654#if M0 > 3
2655 float8 acc3 = 0.0h;
2656#endif // M0 > 3
2657
2658 int i = 0;
2659 for(; i <= ((int)K - 4); i += 4)
2660 {
2661#if defined(REINTERPRET_INPUT_AS_3D)
2662 // Load values from matrix A
2663 LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
2664#else // defined(REINTERPRET_INPUT_AS_3D)
2665 // Load values from matrix A
2666 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2667#if M0 > 1
2668 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2669#endif // M0 > 1
2670#if M0 > 2
2671 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2672#endif // M0 > 2
2673#if M0 > 3
2674 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2675#endif // M0 > 3
2676#endif // defined(REINTERPRET_INPUT_AS_3D)
2677
2678 // Load values from matrix B
2679 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2680 src_addr.s1 += src1_stride_y;
2681
2682 // Accumulate
2683 acc0 = fma(b0, (float8)a0.s0, acc0);
2684#if M0 > 1
2685 acc1 = fma(b0, (float8)a1.s0, acc1);
2686#endif // M0 > 1
2687#if M0 > 2
2688 acc2 = fma(b0, (float8)a2.s0, acc2);
2689#endif // M0 > 2
2690#if M0 > 3
2691 acc3 = fma(b0, (float8)a3.s0, acc3);
2692#endif // M0 > 3
2693
2694 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2695 src_addr.s1 += src1_stride_y;
2696 acc0 = fma(b0, (float8)a0.s1, acc0);
2697#if M0 > 1
2698 acc1 = fma(b0, (float8)a1.s1, acc1);
2699#endif // M0 > 1
2700#if M0 > 2
2701 acc2 = fma(b0, (float8)a2.s1, acc2);
2702#endif // M0 > 2
2703#if M0 > 3
2704 acc3 = fma(b0, (float8)a3.s1, acc3);
2705#endif // M0 > 3
2706
2707 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2708 src_addr.s1 += src1_stride_y;
2709 acc0 = fma(b0, (float8)a0.s2, acc0);
2710#if M0 > 1
2711 acc1 = fma(b0, (float8)a1.s2, acc1);
2712#endif // M0 > 1
2713#if M0 > 2
2714 acc2 = fma(b0, (float8)a2.s2, acc2);
2715#endif // M0 > 2
2716#if M0 > 3
2717 acc3 = fma(b0, (float8)a3.s2, acc3);
2718#endif // M0 > 3
2719
2720 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2721 src_addr.s1 += src1_stride_y;
2722 acc0 = fma(b0, (float8)a0.s3, acc0);
2723#if M0 > 1
2724 acc1 = fma(b0, (float8)a1.s3, acc1);
2725#endif // M0 > 1
2726#if M0 > 2
2727 acc2 = fma(b0, (float8)a2.s3, acc2);
2728#endif // M0 > 2
2729#if M0 > 3
2730 acc3 = fma(b0, (float8)a3.s3, acc3);
2731#endif // M0 > 3
2732
2733 src_addr.s0 += 4 * sizeof(half);
2734 }
2735
2736 for(; i < (int)K; ++i)
2737 {
2738#if defined(REINTERPRET_INPUT_AS_3D)
2739 // Load values from matrix A
2740 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2741#if M0 > 1
2742 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2743#endif // M0 > 1
2744#if M0 > 2
2745 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2746#endif // M0 > 2
2747#if M0 > 3
2748 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2749#endif // M0 > 3
2750#else // defined(REINTERPRET_INPUT_AS_3D)
2751 // Load values from matrix A
2752 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2753#if M0 > 1
2754 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2755#endif // M0 > 1
2756#if M0 > 2
2757 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2758#endif // M0 > 2
2759#if M0 > 3
2760 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2761#endif // M0 > 3
2762#endif // defined(REINTERPRET_INPUT_AS_3D)
2763
2764 // Load values from matrix B
2765 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2766
2767 src_addr += (int2)(sizeof(half), src1_stride_y);
2768
2769 // Accumulate
2770 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
2771#if M0 > 1
2772 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
2773#endif // M0 > 1
2774#if M0 > 2
2775 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
2776#endif // M0 > 2
2777#if M0 > 3
2778 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
2779#endif // M0 > 3
2780 }
2781
2782 int z = get_global_id(2);
2783
SiCong Li4abc9d12020-10-28 14:19:28 +00002784 // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00002785 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00002786
2787 uint4 zout = 0;
2788
2789#if defined(REINTERPRET_OUTPUT_AS_3D)
2790
2791 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2792 // in order to take into account the presence of possible cross plane paddings
2793 //
2794 // | |
2795 // | plane0 |
2796 // | |
2797 // |__________________|
2798 // |******************|
2799 // | cross_plane_pad |
2800 // |******************|
2801 // | |
2802 // | plane1 |
2803 // | |
2804 // |__________________|
2805
SiCong Li0ea50e32020-11-05 09:18:11 +00002806 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
2807 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002808 zout = min(DEPTH_GEMM3D - 1, zout);
2809
2810 // Add offset due to the cross plane paddings
2811 zout *= (dst_cross_plane_pad * dst_stride_y);
2812
2813 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2814 // multiply dst_stride_z by DEPTH_GEMM3D
2815 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2816#else // defined(REINTERPRET_OUTPUT_AS_3D)
2817 // Add offset for batched GEMM
2818 dst_addr += z * dst_stride_z;
2819#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2820
2821 // Multiply by the weight of matrix-matrix product and store the result
2822#if defined(ALPHA)
2823 SCALE_BLOCK(M0, float, acc, ALPHA);
2824#endif // defined(ALPHA)
2825
2826#if defined(BETA)
2827 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2828
2829#if defined(BROADCAST_BIAS)
2830 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
2831
2832 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
2833
2834 float8 bias_f0 = convert_float8(bias0);
2835
2836#ifndef UNIT_BETA
2837 SCALE_BLOCK(1, float, bias_f, BETA);
2838#endif // UNIT_BIAS
2839
2840 // acc = acc + bias[broadcasted]
2841 ADD_BLOCK_BROADCAST(M0, acc, bias_f0);
2842
2843#else // defined(BROADCAST_BIAS)
SiCong Li0ea50e32020-11-05 09:18:11 +00002844 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2845 PARTIAL_STORE_M0)
2846 * src2_stride_y)
2847 + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00002848
2849 LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
2850
2851 float8 bias_f0 = convert_float8(bias0);
2852#if M0 > 1
2853 float8 bias_f1 = convert_float8(bias1);
2854#endif // M0 > 1
2855#if M0 > 2
2856 float8 bias_f2 = convert_float8(bias2);
2857#endif // M0 > 2
2858#if M0 > 3
2859 float8 bias_f3 = convert_float8(bias3);
2860#endif // M0 > 3
2861
2862#ifndef UNIT_BETA
2863 SCALE_BLOCK(M0, float, bias_f, BETA);
2864#endif // UNIT_BIAS
2865
2866 // acc = acc + bias
2867 ADD_BLOCK(M0, acc, bias_f);
2868
2869#endif // defined(BROADCAST_BIAS)
2870#endif // defined(BETA)
2871
2872 half8 acc_h0 = convert_half8(acc0);
2873#if M0 > 1
2874 half8 acc_h1 = convert_half8(acc1);
2875#endif // M0 > 1
2876#if M0 > 2
2877 half8 acc_h2 = convert_half8(acc2);
2878#endif // M0 > 2
2879#if M0 > 3
2880 half8 acc_h3 = convert_half8(acc3);
2881#endif // M0 > 3
2882
2883#if defined(ACTIVATION_TYPE)
2884 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc_h, A_VAL, B_VAL);
2885#endif // defined(ACTIVATION_TYPE)
2886
2887 // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00002888 const bool cond_y = get_global_id(1) == 0;
2889 const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
2890 STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00002891}
2892
2893/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2894 *
2895 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
2896 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
SiCong Li0ea50e32020-11-05 09:18:11 +00002897 * @note This kernel processes a fixed number of elements along x: -DN0=8.
2898 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
2899 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2900 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2901 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00002902 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2903 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
2904 *
2905 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2906 * The activation function is performed after the bias addition
2907 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2908 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2909 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2910 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2911 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2912 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2913 *
2914 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2915 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2916 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2917 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2918 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2919 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2920 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2921 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2922 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2923 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2924 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2925 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2926 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
2927 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2928 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2929 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2930 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2931 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2932 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2933 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2934 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2935 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2936 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2937 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2938 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2939 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2940 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2941 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2942 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2943 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2944 */
2945__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
2946                                                 IMAGE_DECLARATION(src1),
2947#if defined(BETA)
2948                                                 IMAGE_DECLARATION(src2),
2949#endif // defined(BETA)
2950                                                 IMAGE_DECLARATION(dst),
2951                                                 uint src0_stride_z,
2952                                                 uint src1_stride_z,
2953#if defined(BETA)
2954                                                 uint src2_stride_z,
2955#endif //defined(BETA)
2956                                                 uint dst_stride_z
2957#if defined(REINTERPRET_INPUT_AS_3D)
2958                                                 ,
2959                                                 uint src_cross_plane_pad
2960#endif // REINTERPRET_INPUT_AS_3D
2961#if defined(REINTERPRET_OUTPUT_AS_3D)
2962                                                 ,
2963                                                 uint dst_cross_plane_pad
2964#endif // REINTERPRET_OUTPUT_AS_3D
2965                                                 )
2966{
    // First column of matrix B (and of the destination tile) processed by this work-item
2967    int idx = get_global_id(0) * N0;
2968
2969    // Compute starting address for matrix A and Matrix B
2970    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2971
2972    // Update address for the matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00002973    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00002974
2975    // Update address for the matrix B
2976    src_addr.s1 += idx * sizeof(half);
2977
2978#if defined(REINTERPRET_INPUT_AS_3D)
2979    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2980    // in order to take into account the presence of possible cross plane paddings
2981    //
2982    //  |                  |
2983    //  |      plane0      |
2984    //  |                  |
2985    //  |__________________|
2986    //  |******************|
2987    //  |  cross_plane_pad |
2988    //  |******************|
2989    //  |                  |
2990    //  |      plane1      |
2991    //  |                  |
2992    //  |__________________|
2993
SiCong Li0ea50e32020-11-05 09:18:11 +00002994    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
2995    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002996    zin = min(DEPTH_GEMM3D - 1, zin);
2997
2998    // Add offset due to the cross plane paddings
2999    zin *= (src_cross_plane_pad * src0_stride_y);
3000
3001    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3002    // multiply src0_stride_z by DEPTH_GEMM3D
3003    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
3004
3005#else // defined(REINTERPRET_INPUT_AS_3D)
3006
3007    // Add offset for batched GEMM
3008    src_addr.s0 += get_global_id(2) * src0_stride_z;
3009
3010#endif // defined(REINTERPRET_INPUT_AS_3D)
3011
3012#if defined(MATRIX_B_DEPTH)
3013    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
3014    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
3015#else // defined(MATRIX_B_DEPTH)
3016    src_addr.s1 += get_global_id(2) * src1_stride_z;
3017#endif // defined(MATRIX_B_DEPTH)
3018
    // Initialize the accumulators: accN holds the 8 partial results of output row N of the tile
3019    half8 acc0 = 0.0h;
3020#if M0 > 1
3021    half8 acc1 = 0.0h;
3022#endif // M0 > 1
3023#if M0 > 2
3024    half8 acc2 = 0.0h;
3025#endif // M0 > 2
3026#if M0 > 3
3027    half8 acc3 = 0.0h;
3028#endif // M0 > 3
3029
    // Main loop, unrolled by 4 along K: each iteration consumes 4 columns of matrix A
    // (one vload4 per row) and 4 rows of matrix B (one vload8 per row)
3030    int i = 0;
3031    for(; i <= ((int)K - 4); i += 4)
3032    {
3033#if defined(REINTERPRET_INPUT_AS_3D)
3034        // Load values from matrix A
3035        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
3036#else // defined(REINTERPRET_INPUT_AS_3D)
3037        // Load values from matrix A
3038        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3039#if M0 > 1
3040        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3041#endif // M0 > 1
3042#if M0 > 2
3043        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3044#endif // M0 > 2
3045#if M0 > 3
3046        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3047#endif // M0 > 3
3048#endif // defined(REINTERPRET_INPUT_AS_3D)
3049
3050        // Load values from matrix B
3051        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
3052        src_addr.s1 += src1_stride_y;
3053
3054        // Accumulate
3055        acc0 = fma(b0, (half8)a0.s0, acc0);
3056#if M0 > 1
3057        acc1 = fma(b0, (half8)a1.s0, acc1);
3058#endif // M0 > 1
3059#if M0 > 2
3060        acc2 = fma(b0, (half8)a2.s0, acc2);
3061#endif // M0 > 2
3062#if M0 > 3
3063        acc3 = fma(b0, (half8)a3.s0, acc3);
3064#endif // M0 > 3
3065
3066        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
3067        src_addr.s1 += src1_stride_y;
3068        acc0 = fma(b0, (half8)a0.s1, acc0);
3069#if M0 > 1
3070        acc1 = fma(b0, (half8)a1.s1, acc1);
3071#endif // M0 > 1
3072#if M0 > 2
3073        acc2 = fma(b0, (half8)a2.s1, acc2);
3074#endif // M0 > 2
3075#if M0 > 3
3076        acc3 = fma(b0, (half8)a3.s1, acc3);
3077#endif // M0 > 3
3078
3079        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
3080        src_addr.s1 += src1_stride_y;
3081        acc0 = fma(b0, (half8)a0.s2, acc0);
3082#if M0 > 1
3083        acc1 = fma(b0, (half8)a1.s2, acc1);
3084#endif // M0 > 1
3085#if M0 > 2
3086        acc2 = fma(b0, (half8)a2.s2, acc2);
3087#endif // M0 > 2
3088#if M0 > 3
3089        acc3 = fma(b0, (half8)a3.s2, acc3);
3090#endif // M0 > 3
3091
3092        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
3093        src_addr.s1 += src1_stride_y;
3094        acc0 = fma(b0, (half8)a0.s3, acc0);
3095#if M0 > 1
3096        acc1 = fma(b0, (half8)a1.s3, acc1);
3097#endif // M0 > 1
3098#if M0 > 2
3099        acc2 = fma(b0, (half8)a2.s3, acc2);
3100#endif // M0 > 2
3101#if M0 > 3
3102        acc3 = fma(b0, (half8)a3.s3, acc3);
3103#endif // M0 > 3
3104
3105        src_addr.s0 += 4 * sizeof(half);
3106    }
3107
    // Left-over loop: process the remaining (K % 4) accumulations one K step at a time
3108    for(; i < (int)K; ++i)
3109    {
3110#if defined(REINTERPRET_INPUT_AS_3D)
3111        // Load values from matrix A
3112        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
3113#if M0 > 1
3114        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
3115#endif // M0 > 1
3116#if M0 > 2
3117        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
3118#endif // M0 > 2
3119#if M0 > 3
3120        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
3121#endif // M0 > 3
3122#else // defined(REINTERPRET_INPUT_AS_3D)
3123        // Load values from matrix A
3124        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
3125#if M0 > 1
3126        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
3127#endif // M0 > 1
3128#if M0 > 2
3129        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
3130#endif // M0 > 2
3131#if M0 > 3
3132        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
3133#endif // M0 > 3
3134#endif // defined(REINTERPRET_INPUT_AS_3D)
3135
3136        // Load values from matrix B
3137        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
3138
3139        src_addr += (int2)(sizeof(half), src1_stride_y);
3140
3141        // Accumulate
3142        acc0 = fma(b0, (half8)a0, acc0); // b0 * (half8)a0;
3143#if M0 > 1
3144        acc1 = fma(b0, (half8)a1, acc1); // b0 * (half8)a1;
3145#endif // M0 > 1
3146#if M0 > 2
3147        acc2 = fma(b0, (half8)a2, acc2); // b0 * (half8)a2;
3148#endif // M0 > 2
3149#if M0 > 3
3150        acc3 = fma(b0, (half8)a3, acc3); // b0 * (half8)a3;
3151#endif // M0 > 3
3152    }
3153
3154    int z = get_global_id(2);
3155
SiCong Li4abc9d12020-10-28 14:19:28 +00003156    // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00003157    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00003158
3159    uint4 zout = 0;
3160
3161#if defined(REINTERPRET_OUTPUT_AS_3D)
3162
3163    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
3164    // in order to take into account the presence of possible cross plane paddings
3165    //
3166    //  |                  |
3167    //  |      plane0      |
3168    //  |                  |
3169    //  |__________________|
3170    //  |******************|
3171    //  |  cross_plane_pad |
3172    //  |******************|
3173    //  |                  |
3174    //  |      plane1      |
3175    //  |                  |
3176    //  |__________________|
3177
SiCong Li0ea50e32020-11-05 09:18:11 +00003178    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
3179    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00003180    zout = min(DEPTH_GEMM3D - 1, zout);
3181
3182    // Add offset due to the cross plane paddings
3183    zout *= (dst_cross_plane_pad * dst_stride_y);
3184
3185    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
3186    // multiply dst_stride_z by DEPTH_GEMM3D
3187    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
3188#else // defined(REINTERPRET_OUTPUT_AS_3D)
3189    // Add offset for batched GEMM
3190    dst_addr += z * dst_stride_z;
3191#endif // defined(REINTERPRET_OUTPUT_AS_3D)
3192
3193    // Multiply by the weight of matrix-matrix product and store the result
3194#if defined(ALPHA)
3195    SCALE_BLOCK(M0, half, acc, ALPHA);
3196#endif // defined(ALPHA)
3197
3198    // Add beta*bias
3199#if defined(BETA)
3200    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
3201
3202#if defined(BROADCAST_BIAS)
    // Bias is a single row, broadcast onto every row of the output tile
3203    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
3204
3205    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3206
3207#ifndef UNIT_BETA
3208    SCALE_BLOCK(1, half, bias, BETA);
3209#endif // UNIT_BETA
3210
3211    // acc = acc + bias[broadcasted]
3212    ADD_BLOCK_BROADCAST(M0, acc, bias0);
3213
3214#else // defined(BROADCAST_BIAS)
    // Bias has the same shape as the output tile: one 8-wide row per output row
SiCong Li0ea50e32020-11-05 09:18:11 +00003215    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
3216                                                                                                                              PARTIAL_STORE_M0)
3217                                                                                                                              * src2_stride_y)
3218                                + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00003219
3220    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
3221
3222#ifndef UNIT_BETA
3223    SCALE_BLOCK(M0, half, bias, BETA);
3224#endif // UNIT_BETA
3225
3226    // acc = acc + bias
3227    ADD_BLOCK(M0, acc, bias);
3228
3229#endif // defined(BROADCAST_BIAS)
3230#endif // defined(BETA)
3231
3232#if defined(ACTIVATION_TYPE)
3233    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc, A_VAL, B_VAL);
3234#endif // defined(ACTIVATION_TYPE)
3235
3236    // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00003237    const bool cond_y = get_global_id(1) == 0;
3238    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
3239    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00003240}
3241#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3242
SiCong Li0ea50e32020-11-05 09:18:11 +00003243#endif // defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)