blob: 5f8b4f694eac596d443a7c1922ec939835b04ec4 [file] [log] [blame]
SiCong Li4abc9d12020-10-28 14:19:28 +00001/*
2 * Copyright (c) 2020 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "gemm_helpers.h"
25#include "repeat.h"
26
SiCong Li0ea50e32020-11-05 09:18:11 +000027#if defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
SiCong Li4abc9d12020-10-28 14:19:28 +000028/** This OpenCL kernel is optimised for Midgard. It computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
29 *
SiCong Li0ea50e32020-11-05 09:18:11 +000030 * @note The number of rows of destination matrix must be passed at compile time using -DM
31 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +000032 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +000033 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
34 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +000035 * @note The optional alpha's value need to be passed at compile time using -DALPHA
36 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
37 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
38 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
39 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
40 *
41 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
42 * The activation function is performed after the bias addition
43 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
44 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
45 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
46 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
47 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
48 *
49 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
50 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
51 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
52 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
53 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
54 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
55 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
56 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
57 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
58 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
59 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
60 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
61 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
62 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
63 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
64 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
65 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
66 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
67 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
68 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
69 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
70 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
71 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
72 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
73 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
74 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
75 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
76 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
77 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
78 */
79__kernel void gemm_mm_interleaved_transposed_f32(IMAGE_DECLARATION(src0),
80 IMAGE_DECLARATION(src1),
81#if defined(BETA)
82 IMAGE_DECLARATION(src2),
83#endif // defined(BETA)
84 IMAGE_DECLARATION(dst),
85 uint src0_stride_z,
86 uint src1_stride_z,
87#if defined(BETA)
88 uint src2_stride_z,
89#endif //defined(BETA)
90 uint dst_stride_z
91#if defined(REINTERPRET_OUTPUT_AS_3D)
92 ,
93 uint cross_plane_pad
94#endif // REINTERPRET_OUTPUT_AS_3D
95 )
96{
97 int x = get_global_id(0) / H0;
98 int y = get_global_id(1) / V0;
99 int z = get_global_id(2);
100
101 // Offset
102 const int offset_row_a = (get_global_id(1) % V0) * 4;
103 const int offset_row_b = (get_global_id(0) % H0) * 4;
104
105 // src_addr_a = address of matrix A
106 // src_addr_b = address of matrix B
107 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
108 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
109
110#if defined(MATRIX_B_DEPTH)
111 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
112 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
113#else // defined(MATRIX_B_DEPTH)
114 src1_addr_in_bytes += z * src1_stride_z;
115#endif // defined(MATRIX_B_DEPTH)
116
117 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
118 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
119
120 // Compute end row address for matrix B
121 __global float *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(float));
122
123 src_addr_a += offset_row_a;
124 src_addr_b += offset_row_b;
125
126 // Reset accumulators
127 float4 c0 = 0.0f;
128 float4 c1 = 0.0f;
129 float4 c2 = 0.0f;
130 float4 c3 = 0.0f;
131
132 for(; src_addr_b <= (src_end_addr_b - (int)(8 * H0)); src_addr_a += 8 * V0, src_addr_b += 8 * H0)
133 {
134 // Load values from matrix A (interleaved) and matrix B (transposed)
135 float4 a0 = vload4(0, src_addr_a);
136 float4 b0 = vload4(0, src_addr_b);
137
138 c0 += (float4)a0.s0 * b0;
139 c1 += (float4)a0.s1 * b0;
140 c2 += (float4)a0.s2 * b0;
141 c3 += (float4)a0.s3 * b0;
142
143 // Load values from matrix A (interleaved) and matrix B (transposed)
144 a0 = vload4(0, src_addr_a + 4 * V0);
145 b0 = vload4(0, src_addr_b + 4 * H0);
146
147 c0 += (float4)a0.s0 * b0;
148 c1 += (float4)a0.s1 * b0;
149 c2 += (float4)a0.s2 * b0;
150 c3 += (float4)a0.s3 * b0;
151 }
152
153 for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 4 * H0)
154 {
155 // Load values from matrix A (interleaved) and matrix B (transposed)
156 float4 a0 = vload4(0, src_addr_a);
157 float4 b0 = vload4(0, src_addr_b);
158
159 c0 += (float4)a0.s0 * b0;
160 c1 += (float4)a0.s1 * b0;
161 c2 += (float4)a0.s2 * b0;
162 c3 += (float4)a0.s3 * b0;
163 }
164
165 // Compute destination address
166 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
167
168 // Compute dst address
169 __global uchar *dst_addr = offset(&dst, 0, 0);
170
171 uint4 zout = 0;
172
173#if defined(REINTERPRET_OUTPUT_AS_3D)
174 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
175 // in order to take into account the presence of possible cross plane paddings
176 //
177 // | |
178 // | plane0 |
179 // | |
180 // |__________________|
181 // |******************|
182 // | cross_plane_pad |
183 // |******************|
184 // | |
185 // | plane1 |
186 // | |
187 // |__________________|
188
189 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
190 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
191 zout = min(DEPTH_GEMM3D - 1, zout);
192
193 // Add offset due to the cross plane paddings
194 zout *= (cross_plane_pad * dst_stride_y);
195
196 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
197 // multiply dst_stride_z by DEPTH_GEMM3D
198 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
199#else // defined(REINTERPRET_OUTPUT_AS_3D)
200 // Add offset for batched GEMM
201 dst_addr += z * dst_stride_z;
202#endif // defined(REINTERPRET_OUTPUT_AS_3D)
203
204 // Multiply by the weight of matrix-matrix product and store the result
205#if defined(ALPHA)
206 SCALE_BLOCK(4, float, c, ALPHA);
207#endif // defined(ALPHA)
208
209 // Add beta*bias
210#if defined(BETA)
211 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
212
213#if defined(BROADCAST_BIAS)
214 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
215
216 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
217
218#ifndef UNIT_BETA
219 SCALE_BLOCK(1, float, bias, BETA);
220#endif // UNIT_BIAS
221
222 // c = c + bias[broadcasted]
223 ADD_BLOCK_BROADCAST(4, c, bias0);
224
225#else // defined(BROADCAST_BIAS)
226 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
227 2) * src2_stride_z;
228
229 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
230
231#ifndef UNIT_BETA
232 SCALE_BLOCK(4, float, bias, BETA);
233#endif // UNIT_BIAS
234
235 // c = c + bias
236 ADD_BLOCK(4, c, bias);
237
238#endif // defined(BROADCAST_BIAS)
239#endif // defined(BETA)
240
241#if defined(ACTIVATION_TYPE)
242 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
243#endif // defined(ACTIVATION_TYPE)
244
245 // Store 4x4 block
SiCong Li0ea50e32020-11-05 09:18:11 +0000246 const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
247 const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
248 STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +0000249}
250
251/** This OpenCL kernel is optimized for Bifrost and tt computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
252 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000253 * @note The number of rows of destination matrix must be passed at compile time using -DM
254 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000255 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +0000256 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
257 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000258 * @note The optional alpha's value need to be passed at compile time using -DALPHA
259 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
260 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
SiCong Li4abc9d12020-10-28 14:19:28 +0000261 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
262 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
263 *
264 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
265 * The activation function is performed after the bias addition
266 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
267 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
268 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
269 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
270 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
271 *
272 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
273 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
274 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
275 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
276 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
277 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
278 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
279 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
280 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
281 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
282 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
283 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
284 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p lhs_ptr
285 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
286 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
287 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
288 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
289 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
290 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
291 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
292 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
293 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
294 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
295 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
296 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
297 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
298 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
299 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
300 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
301 */
302__kernel void gemm_mm_interleaved_transposed_f32_bifrost(IMAGE_DECLARATION(src0),
303 IMAGE_DECLARATION(src1),
304#if defined(BETA)
305 IMAGE_DECLARATION(src2),
306#endif // defined(BETA)
307 IMAGE_DECLARATION(dst),
308 uint src0_stride_z,
309 uint src1_stride_z,
310#if defined(BETA)
311 uint src2_stride_z,
312#endif //defined(BETA)
313 uint dst_stride_z
314#if defined(REINTERPRET_OUTPUT_AS_3D)
315 ,
316 uint cross_plane_pad
317#endif // REINTERPRET_OUTPUT_AS_3D
318 )
319{
320 int x = get_global_id(0) / H0;
321 int y = get_global_id(1) / V0;
322 int z = get_global_id(2);
323
324 // Offset
325 const int offset_row_a = (get_global_id(1) % V0) * 4;
326 const int offset_row_b = (get_global_id(0) % H0) * 4;
327
328 // src_addr_a = address of matrix A
329 // src_addr_b = address of matrix B
330 int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
331 int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;
332
333#if defined(MATRIX_B_DEPTH)
334 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
335 src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
336#else // defined(MATRIX_B_DEPTH)
337 src1_addr_in_bytes += z * src1_stride_z;
338#endif // defined(MATRIX_B_DEPTH)
339
340 __global float *src_addr_a = (__global float *)(src0_ptr + src0_addr_in_bytes);
341 __global float *src_addr_b = (__global float *)(src1_ptr + src1_addr_in_bytes);
342
343 src_addr_a += offset_row_a;
344 src_addr_b += offset_row_b;
345
346 // Reset accumulators
347 float4 c0 = 0.0f;
348 float4 c1 = 0.0f;
349 float4 c2 = 0.0f;
350 float4 c3 = 0.0f;
351
352 int i = 0;
353 for(; i <= (int)(K - 4); i += 4)
354 {
355 // Load values from matrix A (interleaved) and matrix B (transposed)
356 float4 a0 = vload4(0, src_addr_a);
357 float4 b0 = vload4(0, src_addr_b);
358
359 src_addr_a += 4 * V0;
360 src_addr_b += 4 * H0;
361
362 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
363 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
364 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
365 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
366
367 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
368 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
369 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
370 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
371
372 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
373 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
374 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
375 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
376
377 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
378 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
379 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
380 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
381
382 // Load values from matrix A (interleaved) and matrix B (transposed)
383 a0 = vload4(0, src_addr_a);
384 b0 = vload4(0, src_addr_b);
385
386 src_addr_a += 4 * V0;
387 src_addr_b += 4 * H0;
388
389 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
390 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
391 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
392 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
393
394 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
395 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
396 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
397 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
398
399 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
400 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
401 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
402 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
403
404 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
405 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
406 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
407 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
408
409 // Load values from matrix A (interleaved) and matrix B (transposed)
410 a0 = vload4(0, src_addr_a);
411 b0 = vload4(0, src_addr_b);
412
413 src_addr_a += 4 * V0;
414 src_addr_b += 4 * H0;
415
416 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
417 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
418 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
419 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
420
421 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
422 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
423 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
424 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
425
426 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
427 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
428 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
429 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
430
431 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
432 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
433 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
434 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
435
436 // Load values from matrix A (interleaved) and matrix B (transposed)
437 a0 = vload4(0, src_addr_a);
438 b0 = vload4(0, src_addr_b);
439
440 src_addr_a += 4 * V0;
441 src_addr_b += 4 * H0;
442
443 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
444 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
445 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
446 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
447
448 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
449 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
450 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
451 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
452
453 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
454 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
455 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
456 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
457
458 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
459 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
460 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
461 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
462 }
463
464 for(; i < (int)K; ++i)
465 {
466 // Load values from matrix A (interleaved) and matrix B (transposed)
467 float4 a0 = vload4(0, src_addr_a);
468 float4 b0 = vload4(0, src_addr_b);
469
470 src_addr_a += 4 * V0;
471 src_addr_b += 4 * H0;
472
473 c0.s0 = fma(a0.s0, b0.s0, c0.s0);
474 c0.s1 = fma(a0.s0, b0.s1, c0.s1);
475 c0.s2 = fma(a0.s0, b0.s2, c0.s2);
476 c0.s3 = fma(a0.s0, b0.s3, c0.s3);
477
478 c1.s0 = fma(a0.s1, b0.s0, c1.s0);
479 c1.s1 = fma(a0.s1, b0.s1, c1.s1);
480 c1.s2 = fma(a0.s1, b0.s2, c1.s2);
481 c1.s3 = fma(a0.s1, b0.s3, c1.s3);
482
483 c2.s0 = fma(a0.s2, b0.s0, c2.s0);
484 c2.s1 = fma(a0.s2, b0.s1, c2.s1);
485 c2.s2 = fma(a0.s2, b0.s2, c2.s2);
486 c2.s3 = fma(a0.s2, b0.s3, c2.s3);
487
488 c3.s0 = fma(a0.s3, b0.s0, c3.s0);
489 c3.s1 = fma(a0.s3, b0.s1, c3.s1);
490 c3.s2 = fma(a0.s3, b0.s2, c3.s2);
491 c3.s3 = fma(a0.s3, b0.s3, c3.s3);
492 }
493
494 // Compute destination address
495 Image dst = CONVERT_TO_IMAGE_STRUCT(dst);
496
497 // Compute dst address
498 __global uchar *dst_addr = offset(&dst, 0, 0);
499
500 uint4 zout = 0;
501
502#if defined(REINTERPRET_OUTPUT_AS_3D)
503 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
504 // in order to take into account the presence of possible cross plane paddings
505 //
506 // | |
507 // | plane0 |
508 // | |
509 // |__________________|
510 // |******************|
511 // | cross_plane_pad |
512 // |******************|
513 // | |
514 // | plane1 |
515 // | |
516 // |__________________|
517
518 // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
519 zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
520 zout = min(DEPTH_GEMM3D - 1, zout);
521
522 // Add offset due to the cross plane paddings
523 zout *= (cross_plane_pad * dst_stride_y);
524
525 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
526 // multiply dst_stride_z by DEPTH_GEMM3D
527 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
528#else // defined(REINTERPRET_OUTPUT_AS_3D)
529 // Add offset for batched GEMM
530 dst_addr += z * dst_stride_z;
531#endif // defined(REINTERPRET_OUTPUT_AS_3D)
532
533 // Multiply by the weight of matrix-matrix product and store the result
534#if defined(ALPHA)
535 SCALE_BLOCK(4, float, c, ALPHA);
536#endif // defined(ALPHA)
537
538 // Add beta*bias
539#if defined(BETA)
540 REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);
541
542#if defined(BROADCAST_BIAS)
543 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
544
545 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
546
547#ifndef UNIT_BETA
548 SCALE_BLOCK(1, float, bias, BETA);
549#endif // UNIT_BIAS
550
551 // c = c + bias[broadcasted]
552 ADD_BLOCK_BROADCAST(4, c, bias0);
553
554#else // defined(BROADCAST_BIAS)
555 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (get_global_id(1) * (uint)4 * src2_stride_y) + get_global_id(
556 2) * src2_stride_z;
557
558 LOAD_BLOCK(4, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
559
560#ifndef UNIT_BETA
561 SCALE_BLOCK(4, float, bias, BETA);
562#endif // UNIT_BIAS
563
564 // c = c + bias
565 ADD_BLOCK(4, c, bias);
566
567#endif // defined(BROADCAST_BIAS)
568#endif // defined(BETA)
569
570#if defined(ACTIVATION_TYPE)
571 ACTIVATION_BLOCK(4, ACTIVATION_TYPE, float, VEC_SIZE, c, A_VAL, B_VAL);
572#endif // defined(ACTIVATION_TYPE)
573
574 // Store 4x4 block
SiCong Li0ea50e32020-11-05 09:18:11 +0000575 const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
576 const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
577 STORE_BLOCK_BOUNDARY_AWARE(4, 4, float, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +0000578}
579
580#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
581/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
582 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000583 * @note The number of rows of destination matrix must be passed at compile time using -DM
584 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000585 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +0000586 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
587 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000588 * @note The optional alpha's value need to be passed at compile time using -DALPHA
589 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
590 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
591 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
592 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
593 *
594 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
595 * The activation function is performed after the bias addition
596 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
597 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
598 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
599 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
600 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
601 *
602 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
603 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
604 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
605 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
606 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
607 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
608 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
609 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
610 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
611 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
612 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
613 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
614 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
615 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
616 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
617 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
618 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
619 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
620 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
621 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
622 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
623 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
624 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
625 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
626 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
627 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
628 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
629 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
630 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
631 */
__kernel void gemm_mm_interleaved_transposed_f16(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                )
{
    // Each work-item produces one 4x8 tile of dst.
    // H0 consecutive work-items along X read the same row of src1 and
    // V0 consecutive work-items along Y read the same row of src0.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Element offset of this work-item's sub-block inside the shared src0 / src1 row
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // Byte offsets of the src0 (matrix A) and src1 (matrix B) rows read by this work-item
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // End of the src1 row. It is taken before offset_row_b is applied, so it marks the end
    // of the whole row rather than the end of this work-item's sub-block.
    __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half));

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Accumulators: c<i> holds row i of the 4x8 output tile.
    // The names c0..c3 are addressed by basename ("c") in the block macros used below.
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration (4 values of A and 8 values of B per step)
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;

        // Second step: the next blocks are 4 * V0 (A) and 8 * H0 (B) elements further on
        a0 = vload4(0, src_addr_a + 4 * V0);
        b0 = vload8(0, src_addr_b + 8 * H0);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration until the end of the src1 row
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        c0 += (half8)a0.s0 * b0;
        c1 += (half8)a0.s1 * b0;
        c2 += (half8)a0.s2 * b0;
        c3 += (half8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset passed to the store macro (stays 0 for a plain 2D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (bias0) is loaded and added to all four rows of the tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // A full 4x8 bias tile is loaded from the same (x, y, z) position as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y)
                                + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block. cond_y / cond_x are true when this tile reaches or crosses the last row / column
    // of the M x N output; they are forwarded to the store macro together with PARTIAL_STORE_M0 / PARTIAL_STORE_N0.
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
804
/** This OpenCL kernel computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1) while accumulating the result in a 32-bit floating point variable.
806 *
SiCong Li0ea50e32020-11-05 09:18:11 +0000807 * @note The number of rows of destination matrix must be passed at compile time using -DM
808 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +0000809 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +0000810 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
811 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +0000812 * @note The optional alpha's value need to be passed at compile time using -DALPHA
813 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
814 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
815 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
816 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
817 *
818 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
819 * The activation function is performed after the bias addition
820 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
821 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
822 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
823 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
824 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
825 *
826 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
827 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
828 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
829 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
830 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
831 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
832 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
833 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
834 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
835 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
836 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
837 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
839 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
840 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
841 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
842 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
843 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
844 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
845 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
846 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
847 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
848 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
849 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
850 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
851 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
852 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
853 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
854 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
855 */
__kernel void gemm_mm_interleaved_transposed_f16_acc32(IMAGE_DECLARATION(src0),
                                                       IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                       IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                       IMAGE_DECLARATION(dst),
                                                       uint src0_stride_z,
                                                       uint src1_stride_z,
#if defined(BETA)
                                                       uint src2_stride_z,
#endif // defined(BETA)
                                                       uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                       ,
                                                       uint cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                      )
{
    // Each work-item produces one 4x8 tile of dst.
    // H0 consecutive work-items along X read the same row of src1 and
    // V0 consecutive work-items along Y read the same row of src0.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Element offset of this work-item's sub-block inside the shared src0 / src1 row
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // Byte offsets of the src0 (matrix A) and src1 (matrix B) rows read by this work-item
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    // End of the src1 row. It is taken before offset_row_b is applied, so it marks the end
    // of the whole row rather than the end of this work-item's sub-block.
    __global half *src_end_addr_b = src_addr_b + (src1_stride_y / sizeof(half));

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Accumulators are kept in float; the inputs are half and are converted on load.
    // c<i> holds row i of the 4x8 output tile; the names c0..c3 are addressed by basename ("c")
    // in the block macros used below.
    float8 c0 = 0.0f;
    float8 c1 = 0.0f;
    float8 c2 = 0.0f;
    float8 c3 = 0.0f;

    // Main loop: two accumulation steps per iteration (4 values of A and 8 values of B per step)
    for(; src_addr_b <= (src_end_addr_b - (int)(16 * H0)); src_addr_a += 8 * V0, src_addr_b += 16 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;

        // Second step: the next blocks are 4 * V0 (A) and 8 * H0 (B) elements further on
        a0 = convert_float4(vload4(0, src_addr_a + 4 * V0));
        b0 = convert_float8(vload8(0, src_addr_b + 8 * H0));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Left-over loop: one accumulation step per iteration until the end of the src1 row
    for(; src_addr_b < src_end_addr_b; src_addr_a += 4 * V0, src_addr_b += 8 * H0)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        float4 a0 = convert_float4(vload4(0, src_addr_a));
        float8 b0 = convert_float8(vload8(0, src_addr_b));

        c0 += (float8)a0.s0 * b0;
        c1 += (float8)a0.s1 * b0;
        c2 += (float8)a0.s2 * b0;
        c3 += (float8)a0.s3 * b0;
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset passed to the store macro (stays 0 for a plain 2D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, float, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias (the bias is loaded as half and widened to float before scaling and adding)
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (bias0) is loaded and added to all four rows of the tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias_f0);

#else // defined(BROADCAST_BIAS)
    // A full 4x8 bias tile is loaded from the same (x, y, z) position as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y)
                                + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

    float8 bias_f0 = convert_float8(bias0);
    float8 bias_f1 = convert_float8(bias1);
    float8 bias_f2 = convert_float8(bias2);
    float8 bias_f3 = convert_float8(bias3);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, float, bias_f, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias_f);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

    // Narrow the float accumulators to half; activation and store operate on the half values
    half8 c_h0 = convert_half8(c0);
    half8 c_h1 = convert_half8(c1);
    half8 c_h2 = convert_half8(c2);
    half8 c_h3 = convert_half8(c3);

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c_h, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block. cond_y / cond_x are true when this tile reaches or crosses the last row / column
    // of the M x N output; they are forwarded to the store macro together with PARTIAL_STORE_M0 / PARTIAL_STORE_N0.
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
1038
1039/** This OpenCL kernel optimized for Bifrost architectures computes the matrix multiplication between matrix A reshaped (src0) and matrix B reshaped (src1)
1040 *
SiCong Li0ea50e32020-11-05 09:18:11 +00001041 * @note The number of rows of destination matrix must be passed at compile time using -DM
1042 * @note The number of columns of the destination matrix must be passed at compile time using -DN
SiCong Li4abc9d12020-10-28 14:19:28 +00001043 * @note The number of rows of the *un-reshaped* matrix B (K) must be passed at compile time using -DK
SiCong Li0ea50e32020-11-05 09:18:11 +00001044 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1045 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
SiCong Li4abc9d12020-10-28 14:19:28 +00001046 * @note The optional alpha's value need to be passed at compile time using -DALPHA
1047 * @note The multiplication factor for the transposition width (H0) must be passed at compile time using -DH0 (e.g. -DH0=2)
1048 * @note The multiplication factor for the height of the 4x4 interleaved block must be passed at compile time using -DV0 (e.g. -DV0=2)
1049 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1050 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
1051 *
1052 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1053 * The activation function is performed after the bias addition
1054 * @note In case the output has to be reinterpreted as a 3D tensor (e.g. output of convolution layer), the following information must be passed at compile time:
1055 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1056 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1057 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1058 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1059 *
1060 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
1061 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1062 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1063 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1064 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1065 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1066 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1067 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1068 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1069 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1070 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1071 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
1073 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1074 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
1075 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1076 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
1077 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1078 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1079 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1080 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1081 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1082 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1083 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1084 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1085 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1087 * @param[in] cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1088 */
__kernel void gemm_mm_interleaved_transposed_f16_bifrost(IMAGE_DECLARATION(src0),
                                                         IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                         IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                         IMAGE_DECLARATION(dst),
                                                         uint src0_stride_z,
                                                         uint src1_stride_z,
#if defined(BETA)
                                                         uint src2_stride_z,
#endif // defined(BETA)
                                                         uint dst_stride_z
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                         ,
                                                         uint cross_plane_pad
#endif // defined(REINTERPRET_OUTPUT_AS_3D)
                                                        )
{
    // Each work-item produces one 4x8 tile of dst.
    // H0 consecutive work-items along X read the same row of src1 and
    // V0 consecutive work-items along Y read the same row of src0.
    int x = get_global_id(0) / H0;
    int y = get_global_id(1) / V0;
    int z = get_global_id(2);

    // Element offset of this work-item's sub-block inside the shared src0 / src1 row
    const int offset_row_a = (get_global_id(1) % V0) * 4;
    const int offset_row_b = (get_global_id(0) % H0) * 8;

    // Byte offsets of the src0 (matrix A) and src1 (matrix B) rows read by this work-item
    int src0_addr_in_bytes = z * src0_stride_z + y * src0_stride_y + src0_offset_first_element_in_bytes;
    int src1_addr_in_bytes = x * src1_stride_y + src1_offset_first_element_in_bytes;

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src1_addr_in_bytes += (z % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src1_addr_in_bytes += z * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    __global half *src_addr_a = (__global half *)(src0_ptr + src0_addr_in_bytes);
    __global half *src_addr_b = (__global half *)(src1_ptr + src1_addr_in_bytes);

    src_addr_a += offset_row_a;
    src_addr_b += offset_row_b;

    // Accumulators: c<i> holds row i of the 4x8 output tile.
    // The names c0..c3 are addressed by basename ("c") in the block macros used below.
    half8 c0 = 0.0f;
    half8 c1 = 0.0f;
    half8 c2 = 0.0f;
    half8 c3 = 0.0f;

    // Main loop: four accumulation steps along K per iteration.
    // Unlike the pointer-bounded loops of the other f16 kernels, the trip count comes from K.
    int i = 0;
    for(; i <= (int)(K - 4); i += 4)
    {
#if V0 == 1
        // With V0 == 1 the A blocks of two consecutive steps are 4 elements apart, i.e. contiguous,
        // so one vload8 fetches both: a0.s0-s3 feed the first step and a0.s4-s7 the second.
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half8 a0 = vload8(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload8(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 8 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix B (transposed)
        b0 = vload8(0, src_addr_b);

        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s4, b0, c0);
        c1 = fma((half8)a0.s5, b0, c1);
        c2 = fma((half8)a0.s6, b0, c2);
        c3 = fma((half8)a0.s7, b0, c3);
#else  // V0 == 1
        // General case: one vload4 of A and one vload8 of B per step, repeated four times
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);

        // Load values from matrix A (interleaved) and matrix B (transposed)
        a0 = vload4(0, src_addr_a);
        b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
#endif // V0 == 1
    }

    // Left-over loop: the remaining K % 4 steps, one per iteration
    for(; i < (int)K; ++i)
    {
        // Load values from matrix A (interleaved) and matrix B (transposed)
        half4 a0 = vload4(0, src_addr_a);
        half8 b0 = vload8(0, src_addr_b);

        src_addr_a += 4 * V0;
        src_addr_b += 8 * H0;

        c0 = fma((half8)a0.s0, b0, c0);
        c1 = fma((half8)a0.s1, b0, c1);
        c2 = fma((half8)a0.s2, b0, c2);
        c3 = fma((half8)a0.s3, b0, c3);
    }

    // Compute destination address
    Image dst = CONVERT_TO_IMAGE_STRUCT(dst);

    // Compute dst address
    __global uchar *dst_addr = offset(&dst, 0, 0);

    // Per-row extra byte offset passed to the store macro (stays 0 for a plain 2D output)
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)
    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing M (get_global_id(1) * 4) by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(get_global_id(1) * 4)) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(4, half, c, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(4, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row (bias0) is loaded and added to all four rows of the tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias[broadcasted]
    ADD_BLOCK_BROADCAST(4, c, bias0);

#else // defined(BROADCAST_BIAS)
    // A full 4x8 bias tile is loaded from the same (x, y, z) position as the output tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (get_global_id(1) * (uint)4 * src2_stride_y)
                                + get_global_id(2) * src2_stride_z;

    LOAD_BLOCK(4, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(4, half, bias, BETA);
#endif // UNIT_BETA

    // c = c + bias
    ADD_BLOCK(4, c, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(4, ACTIVATION_TYPE, half, VEC_SIZE, c, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store 4x8 block. cond_y / cond_x are true when this tile reaches or crosses the last row / column
    // of the M x N output; they are forwarded to the store macro together with PARTIAL_STORE_M0 / PARTIAL_STORE_N0.
    const bool cond_y = ((get_global_id(1) + 1) * 4 >= M);
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(4, 8, half, c, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
1337
1338#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
1339
SiCong Li0ea50e32020-11-05 09:18:11 +00001340#endif // defined(M) && defined(N) && defined(K) && defined(H0) && defined(V0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
SiCong Li4abc9d12020-10-28 14:19:28 +00001341
SiCong Li0ea50e32020-11-05 09:18:11 +00001342#if defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)
#if defined(DATA_TYPE)
#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, N0)
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped.
 *
 * Each work-item accumulates an M0 x N0 tile of the destination: the reduction dimension (K) is walked two elements at a time,
 * a tail loop handles the remaining element when K is odd, then the optional alpha scaling, bias addition and activation are applied
 * before the tile is written with a boundary-aware store.
 *
 * @note This OpenCL kernel works with floating point data types (F16/F32)
 * @note The floating point data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float)
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha value must be passed at compile time using -DALPHA
 * @note The optional beta value must be passed at compile time using -DBETA. When BETA is defined the bias matrix (src2) becomes a kernel argument.
 *       -DUNIT_BETA skips the multiplication of the bias by BETA, and -DBROADCAST_BIAS adds a single bias row to every row of the tile.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16/F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the output tensor (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point(IMAGE_DECLARATION(src0),
                                     IMAGE_DECLARATION(src1),
#if defined(BETA)
                                     IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                     IMAGE_DECLARATION(dst),
                                     uint src0_stride_z,
                                     uint src1_stride_z,
#if defined(BETA)
                                     uint src2_stride_z,
#endif //defined(BETA)
                                     uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                     ,
                                     uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                     ,
                                     uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                    )
{
    // First column (in elements) of the tile handled by this work-item
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B
    // (.s0 tracks the byte offset into matrix A, .s1 the byte offset into matrix B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(DATA_TYPE);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Byte offset one past the last element of the first row of A read by this work-item
    int end_row_vec_a = src_addr.s0 + (K * sizeof(DATA_TYPE));

    // One accumulator vector (N0 elements) per row of the tile
    VECTOR_TYPE acc0 = 0.0f;
#if M0 > 1
    VECTOR_TYPE acc1 = 0.0f;
#endif // M0 > 1
#if M0 > 2
    VECTOR_TYPE acc2 = 0.0f;
#endif // M0 > 2
#if M0 > 3
    VECTOR_TYPE acc3 = 0.0f;
#endif // M0 > 3

    // Main loop: consume two elements of K per iteration (two columns of A, two rows of B)
    for(; src_addr.s0 <= (end_row_vec_a - 2 * (int)sizeof(DATA_TYPE)); src_addr += (int2)(2 * sizeof(DATA_TYPE), 2 * src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 2, DATA_TYPE, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        VEC_DATA_TYPE(DATA_TYPE, 2)
        a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
        VECTOR_TYPE b1 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0.s0;
        acc0 += b1 * (VECTOR_TYPE)a0.s1;
#if M0 > 1
        acc1 += b0 * (VECTOR_TYPE)a1.s0;
        acc1 += b1 * (VECTOR_TYPE)a1.s1;
#endif // M0 > 1
#if M0 > 2
        acc2 += b0 * (VECTOR_TYPE)a2.s0;
        acc2 += b1 * (VECTOR_TYPE)a2.s1;
#endif // M0 > 2
#if M0 > 3
        acc3 += b0 * (VECTOR_TYPE)a3.s0;
        acc3 += b1 * (VECTOR_TYPE)a3.s1;
#endif // M0 > 3
    }

    // Tail loop: consume the element of K left over by the main loop (if any), one at a time
    for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(sizeof(DATA_TYPE), src1_stride_y))
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        VECTOR_TYPE b0 = VLOAD(N0)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));

        // Accumulate
        acc0 += b0 * (VECTOR_TYPE)a0;
#if M0 > 1
        acc1 += b0 * (VECTOR_TYPE)a1;
#endif // M0 > 1
#if M0 > 2
        acc2 += b0 * (VECTOR_TYPE)a2;
#endif // M0 > 2
#if M0 > 3
        acc3 += b0 * (VECTOR_TYPE)a3;
#endif // M0 > 3
    }

    // Batch index
    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                               PARTIAL_STORE_M0)
                               * dst_stride_y);

    // Per-row cross-plane byte offsets for the store; stays zero unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, DATA_TYPE, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is shared by every row of the tile: only the column offset is applied
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE));

    LOAD_BLOCK(1, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // The bias tile is addressed like the destination tile: same column, start row and batch
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)N0 * sizeof(DATA_TYPE)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
                                PARTIAL_STORE_M0)
                                * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, N0, DATA_TYPE, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, DATA_TYPE, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store output block
    // NOTE(review): the partial block in y is taken by the first work-item in y, which presumably matches how
    // COMPUTE_M0_START_ROW shifts the start row of the following work-items - confirm against gemm_helpers.h
    const bool cond_y = get_global_id(1) == 0;
    // The partial block in x is taken by the work-item whose tile reaches or crosses column N
    const bool cond_x = ((get_global_id(0) + 1) * N0 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, N0, DATA_TYPE, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
#endif // defined(DATA_TYPE)
1662
1663/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
1664 *
1665 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
1666 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=4.
1668 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
1669 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
1670 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
1671 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00001672 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
1673 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
1674 *
1675 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
1676 * The activation function is performed after the bias addition
1677 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
1678 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
1679 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
1680 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
1681 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
1682 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
1683 *
1684 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F32
1685 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
1686 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1687 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
1688 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1689 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
1690 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
1691 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
1692 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1693 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
1694 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1695 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
1697 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
1698 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
1699 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
1700 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
1701 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
1702 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
1703 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
1704 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
1705 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
1706 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
1707 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
1708 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
1709 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
1710 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
1711 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1712 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
1713 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
1714 */
1715__kernel void gemm_mm_floating_point_f32_bifrost(IMAGE_DECLARATION(src0),
1716 IMAGE_DECLARATION(src1),
1717#if defined(BETA)
1718 IMAGE_DECLARATION(src2),
1719#endif // defined(BETA)
1720 IMAGE_DECLARATION(dst),
1721 uint src0_stride_z,
1722 uint src1_stride_z,
1723#if defined(BETA)
1724 uint src2_stride_z,
1725#endif //defined(BETA)
1726 uint dst_stride_z
1727#if defined(REINTERPRET_INPUT_AS_3D)
1728 ,
1729 uint src_cross_plane_pad
1730#endif // REINTERPRET_INPUT_AS_3D
1731#if defined(REINTERPRET_OUTPUT_AS_3D)
1732 ,
1733 uint dst_cross_plane_pad
1734#endif // REINTERPRET_OUTPUT_AS_3D
1735 )
1736{
1737 int idx = get_global_id(0) * N0;
1738
1739 // Compute starting address for matrix A and matrix B
1740 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
1741
1742 // Update address for matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00001743 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00001744
1745 // Update address for matrix B
1746 src_addr.s1 += idx * sizeof(float);
1747
1748#if defined(REINTERPRET_INPUT_AS_3D)
1749 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
1750 // in order to take into account the presence of possible cross plane paddings
1751 //
1752 // | |
1753 // | plane0 |
1754 // | |
1755 // |__________________|
1756 // |******************|
1757 // | cross_plane_pad |
1758 // |******************|
1759 // | |
1760 // | plane1 |
1761 // | |
1762 // |__________________|
1763
SiCong Li0ea50e32020-11-05 09:18:11 +00001764 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
1765 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00001766 zin = min(DEPTH_GEMM3D - 1, zin);
1767
1768 // Add offset due to the cross plane paddings
1769 zin *= (src_cross_plane_pad * src0_stride_y);
1770
1771 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
1772 // multiply src0_stride_z by DEPTH_GEMM3D
1773 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
1774
1775#else // defined(REINTERPRET_INPUT_AS_3D)
1776
1777 // Add offset for batched GEMM
1778 src_addr.s0 += get_global_id(2) * src0_stride_z;
1779
1780#endif // defined(REINTERPRET_INPUT_AS_3D)
1781
1782#if defined(MATRIX_B_DEPTH)
1783 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
1784 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
1785#else // defined(MATRIX_B_DEPTH)
1786 src_addr.s1 += get_global_id(2) * src1_stride_z;
1787#endif // defined(MATRIX_B_DEPTH)
1788
1789 // Initialize accumulators
1790 float4 acc0 = 0.0f;
1791
1792#if M0 > 1
1793 float4 acc1 = 0.0f;
1794#endif // M0 > 1
1795
1796#if M0 > 2
1797 float4 acc2 = 0.0f;
1798#endif // M0 > 2
1799
1800#if M0 > 3
1801 float4 acc3 = 0.0f;
1802#endif // M0 > 3
1803
1804 // A and B src indices get incremented at the same time.
1805 int i = 0;
1806 for(; i <= ((int)K - 4); i += 4)
1807 {
1808#if defined(REINTERPRET_INPUT_AS_3D)
1809 // Load values from matrix A and matrix B
1810 LOAD_BLOCK(M0, 4, float, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
1811#else // defined(REINTERPRET_INPUT_AS_3D)
1812 // Load values from matrix A and matrix B
1813 float4 a0 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1814#if M0 > 1
1815 float4 a1 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1816#endif // M0 > 1
1817#if M0 > 2
1818 float4 a2 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1819#endif // M0 > 2
1820#if M0 > 3
1821 float4 a3 = vload4(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1822#endif // M0 > 3
1823#endif // defined(REINTERPRET_INPUT_AS_3D)
1824
1825 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1826 src_addr.s1 += src1_stride_y;
1827
1828 // Multiply and accumulate
1829 acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
1830 acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
1831 acc0.s2 = fma(a0.s0, b0.s2, acc0.s2);
1832 acc0.s3 = fma(a0.s0, b0.s3, acc0.s3);
1833
1834#if M0 > 1
1835
1836 acc1.s0 = fma(a1.s0, b0.s0, acc1.s0);
1837 acc1.s1 = fma(a1.s0, b0.s1, acc1.s1);
1838 acc1.s2 = fma(a1.s0, b0.s2, acc1.s2);
1839 acc1.s3 = fma(a1.s0, b0.s3, acc1.s3);
1840
1841#endif // M0 > 1
1842#if M0 > 2
1843
1844 acc2.s0 = fma(a2.s0, b0.s0, acc2.s0);
1845 acc2.s1 = fma(a2.s0, b0.s1, acc2.s1);
1846 acc2.s2 = fma(a2.s0, b0.s2, acc2.s2);
1847 acc2.s3 = fma(a2.s0, b0.s3, acc2.s3);
1848
1849#endif // M0 > 2
1850#if M0 > 3
1851
1852 acc3.s0 = fma(a3.s0, b0.s0, acc3.s0);
1853 acc3.s1 = fma(a3.s0, b0.s1, acc3.s1);
1854 acc3.s2 = fma(a3.s0, b0.s2, acc3.s2);
1855 acc3.s3 = fma(a3.s0, b0.s3, acc3.s3);
1856#endif // M0 > 3
1857
1858 // Load values from matrix A and matrix B
1859 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1860 src_addr.s1 += src1_stride_y;
1861
1862 // Multiply and accumulate
1863 acc0.s0 = fma(a0.s1, b0.s0, acc0.s0);
1864 acc0.s1 = fma(a0.s1, b0.s1, acc0.s1);
1865 acc0.s2 = fma(a0.s1, b0.s2, acc0.s2);
1866 acc0.s3 = fma(a0.s1, b0.s3, acc0.s3);
1867
1868#if M0 > 1
1869
1870 acc1.s0 = fma(a1.s1, b0.s0, acc1.s0);
1871 acc1.s1 = fma(a1.s1, b0.s1, acc1.s1);
1872 acc1.s2 = fma(a1.s1, b0.s2, acc1.s2);
1873 acc1.s3 = fma(a1.s1, b0.s3, acc1.s3);
1874
1875#endif // M0 > 1
1876#if M0 > 2
1877
1878 acc2.s0 = fma(a2.s1, b0.s0, acc2.s0);
1879 acc2.s1 = fma(a2.s1, b0.s1, acc2.s1);
1880 acc2.s2 = fma(a2.s1, b0.s2, acc2.s2);
1881 acc2.s3 = fma(a2.s1, b0.s3, acc2.s3);
1882
1883#endif // M0 > 2
1884#if M0 > 3
1885
1886 acc3.s0 = fma(a3.s1, b0.s0, acc3.s0);
1887 acc3.s1 = fma(a3.s1, b0.s1, acc3.s1);
1888 acc3.s2 = fma(a3.s1, b0.s2, acc3.s2);
1889 acc3.s3 = fma(a3.s1, b0.s3, acc3.s3);
1890#endif // M0 > 3
1891
1892 // Load values from matrix A and matrix B
1893 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1894 src_addr.s1 += src1_stride_y;
1895
1896 // Multiply and accumulate
1897 acc0.s0 = fma(a0.s2, b0.s0, acc0.s0);
1898 acc0.s1 = fma(a0.s2, b0.s1, acc0.s1);
1899 acc0.s2 = fma(a0.s2, b0.s2, acc0.s2);
1900 acc0.s3 = fma(a0.s2, b0.s3, acc0.s3);
1901
1902#if M0 > 1
1903
1904 acc1.s0 = fma(a1.s2, b0.s0, acc1.s0);
1905 acc1.s1 = fma(a1.s2, b0.s1, acc1.s1);
1906 acc1.s2 = fma(a1.s2, b0.s2, acc1.s2);
1907 acc1.s3 = fma(a1.s2, b0.s3, acc1.s3);
1908
1909#endif // M0 > 1
1910#if M0 > 2
1911
1912 acc2.s0 = fma(a2.s2, b0.s0, acc2.s0);
1913 acc2.s1 = fma(a2.s2, b0.s1, acc2.s1);
1914 acc2.s2 = fma(a2.s2, b0.s2, acc2.s2);
1915 acc2.s3 = fma(a2.s2, b0.s3, acc2.s3);
1916
1917#endif // M0 > 2
1918#if M0 > 3
1919
1920 acc3.s0 = fma(a3.s2, b0.s0, acc3.s0);
1921 acc3.s1 = fma(a3.s2, b0.s1, acc3.s1);
1922 acc3.s2 = fma(a3.s2, b0.s2, acc3.s2);
1923 acc3.s3 = fma(a3.s2, b0.s3, acc3.s3);
1924#endif // M0 > 3
1925
1926 // Load values from matrix A and matrix B
1927 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1928 src_addr.s1 += src1_stride_y;
1929
1930 // Multiply and accumulate
1931 acc0.s0 = fma(a0.s3, b0.s0, acc0.s0);
1932 acc0.s1 = fma(a0.s3, b0.s1, acc0.s1);
1933 acc0.s2 = fma(a0.s3, b0.s2, acc0.s2);
1934 acc0.s3 = fma(a0.s3, b0.s3, acc0.s3);
1935
1936#if M0 > 1
1937
1938 acc1.s0 = fma(a1.s3, b0.s0, acc1.s0);
1939 acc1.s1 = fma(a1.s3, b0.s1, acc1.s1);
1940 acc1.s2 = fma(a1.s3, b0.s2, acc1.s2);
1941 acc1.s3 = fma(a1.s3, b0.s3, acc1.s3);
1942
1943#endif // M0 > 1
1944#if M0 > 2
1945
1946 acc2.s0 = fma(a2.s3, b0.s0, acc2.s0);
1947 acc2.s1 = fma(a2.s3, b0.s1, acc2.s1);
1948 acc2.s2 = fma(a2.s3, b0.s2, acc2.s2);
1949 acc2.s3 = fma(a2.s3, b0.s3, acc2.s3);
1950
1951#endif // M0 > 2
1952#if M0 > 3
1953
1954 acc3.s0 = fma(a3.s3, b0.s0, acc3.s0);
1955 acc3.s1 = fma(a3.s3, b0.s1, acc3.s1);
1956 acc3.s2 = fma(a3.s3, b0.s2, acc3.s2);
1957 acc3.s3 = fma(a3.s3, b0.s3, acc3.s3);
1958#endif // M0 > 3
1959
1960 src_addr.s0 += 4 * sizeof(float);
1961 }
1962
1963 for(; i < (int)K; ++i)
1964 {
1965#if defined(REINTERPRET_INPUT_AS_3D)
1966 // Load values from matrix A
1967 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
1968#if M0 > 1
1969 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
1970#endif // M0 > 1
1971#if M0 > 2
1972 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
1973#endif // M0 > 2
1974#if M0 > 3
1975 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
1976#endif // M0 > 3
1977#else // defined(REINTERPRET_INPUT_AS_3D)
1978 // Load values from matrix A
1979 float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
1980#if M0 > 1
1981 float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
1982#endif // M0 > 1
1983#if M0 > 2
1984 float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
1985#endif // M0 > 2
1986#if M0 > 3
1987 float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
1988#endif // M0 > 3
1989#endif // defined(REINTERPRET_INPUT_AS_3D)
1990
1991 // Load values from matrix B
1992 float4 b0 = vload4(0, (__global float *)(src1_ptr + src_addr.s1));
1993 src_addr.s1 += src1_stride_y;
1994
1995 // Multiply and accumulate
1996 acc0.s0 = fma(a0, b0.s0, acc0.s0);
1997 acc0.s1 = fma(a0, b0.s1, acc0.s1);
1998 acc0.s2 = fma(a0, b0.s2, acc0.s2);
1999 acc0.s3 = fma(a0, b0.s3, acc0.s3);
2000#if M0 > 1
2001 acc1.s0 = fma(a1, b0.s0, acc1.s0);
2002 acc1.s1 = fma(a1, b0.s1, acc1.s1);
2003 acc1.s2 = fma(a1, b0.s2, acc1.s2);
2004 acc1.s3 = fma(a1, b0.s3, acc1.s3);
2005#endif // M0 > 1
2006#if M0 > 2
2007 acc2.s0 = fma(a2, b0.s0, acc2.s0);
2008 acc2.s1 = fma(a2, b0.s1, acc2.s1);
2009 acc2.s2 = fma(a2, b0.s2, acc2.s2);
2010 acc2.s3 = fma(a2, b0.s3, acc2.s3);
2011#endif // M0 > 2
2012#if M0 > 3
2013 acc3.s0 = fma(a3, b0.s0, acc3.s0);
2014 acc3.s1 = fma(a3, b0.s1, acc3.s1);
2015 acc3.s2 = fma(a3, b0.s2, acc3.s2);
2016 acc3.s3 = fma(a3, b0.s3, acc3.s3);
2017#endif // M0 > 3
2018
2019 src_addr.s0 += sizeof(float);
2020 }
2021
2022 int z = get_global_id(2);
2023
SiCong Li4abc9d12020-10-28 14:19:28 +00002024 // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00002025 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2026 PARTIAL_STORE_M0) * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00002027
2028 uint4 zout = 0;
2029
2030#if defined(REINTERPRET_OUTPUT_AS_3D)
2031 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2032 // in order to take into account the presence of possible cross plane paddings
2033 //
2034 // | |
2035 // | plane0 |
2036 // | |
2037 // |__________________|
2038 // |******************|
2039 // | cross_plane_pad |
2040 // |******************|
2041 // | |
2042 // | plane1 |
2043 // | |
2044 // |__________________|
2045
SiCong Li0ea50e32020-11-05 09:18:11 +00002046 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
2047 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002048 zout = min(DEPTH_GEMM3D - 1, zout);
2049
2050 // Add offset due to the cross plane paddings
2051 zout *= (dst_cross_plane_pad * dst_stride_y);
2052
2053 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2054 // multiply dst_stride_z by DEPTH_GEMM3D
2055 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2056#else // defined(REINTERPRET_OUTPUT_AS_3D)
2057 // Add offset for batched GEMM
2058 dst_addr += z * dst_stride_z;
2059#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2060
2061 // Multiply by the weight of matrix-matrix product and store the result
2062#if defined(ALPHA)
2063 SCALE_BLOCK(M0, float, acc, ALPHA);
2064#endif // defined(ALPHA)
2065
2066 // Add beta*bias
2067#if defined(BETA)
2068 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2069
2070#if defined(BROADCAST_BIAS)
2071 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float));
2072
2073 LOAD_BLOCK(1, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2074
2075#ifndef UNIT_BETA
2076 SCALE_BLOCK(1, float, bias, BETA);
2077#endif // UNIT_BIAS
2078
2079 // acc = acc + bias[broadcasted]
2080 ADD_BLOCK_BROADCAST(M0, acc, bias0);
2081
2082#else // defined(BROADCAST_BIAS)
SiCong Li0ea50e32020-11-05 09:18:11 +00002083 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)4 * sizeof(float)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2084 PARTIAL_STORE_M0)
2085 * src2_stride_y)
2086 + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00002087
2088 LOAD_BLOCK(M0, 4, float, bias, src2_addr, 0, src2_stride_y, zero);
2089
2090#ifndef UNIT_BETA
2091 SCALE_BLOCK(M0, float, bias, BETA);
2092#endif // UNIT_BIAS
2093
2094 // acc = acc + bias
2095 ADD_BLOCK(M0, acc, bias);
2096
2097#endif // defined(BROADCAST_BIAS)
2098#endif // defined(BETA)
2099
2100#if defined(ACTIVATION_TYPE)
2101 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
2102#endif // defined(ACTIVATION_TYPE)
2103
2104 // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00002105 const bool cond_y = get_global_id(1) == 0;
2106 const bool cond_x = ((get_global_id(0) + 1) * 4 >= N);
2107 STORE_BLOCK_BOUNDARY_AWARE(M0, 4, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00002108}
2109
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * Each work-item accumulates an M0 x 2 tile of the destination: the K dimension is consumed 8 elements at a time,
 * followed by a scalar tail loop for the remaining (K % 8) elements. The tile is then optionally scaled by ALPHA,
 * summed with BETA * bias, passed through the activation and stored.
 *
 * @note This OpenCL kernel works with the 32-bit floating point data type (float) and uses the fma units.
 *       This OpenCL kernel is optimized for Bifrost when the number of matrix B columns is less or equal to 1000.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=2.
 * @note The kernel body provides accumulators for up to four rows, i.e. M0 is expected to be in the range [1, 4].
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha's value need to be passed at compile time using -DALPHA
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F32
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f32_bifrost_1000(IMAGE_DECLARATION(src0),
                                                      IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                      IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                      IMAGE_DECLARATION(dst),
                                                      uint src0_stride_z,
                                                      uint src1_stride_z,
#if defined(BETA)
                                                      uint src2_stride_z,
#endif // defined(BETA)
                                                      uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                      ,
                                                      uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                      ,
                                                      uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                     )
{
    // Index of the first matrix B / destination column handled by this work-item (N0 columns per work-item)
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B (.s0 -> byte offset in A, .s1 -> byte offset in B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(float);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    // (after this, zin.sX is the extra byte offset to add when loading row X of the tile)
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // Initialize accumulators (one float2 per output row of the tile)
    float2 acc0 = 0.0f;
#if M0 > 1
    float2 acc1 = 0.0f;
#endif // M0 > 1
#if M0 > 2
    float2 acc2 = 0.0f;
#endif // M0 > 2
#if M0 > 3
    float2 acc3 = 0.0f;
#endif // M0 > 3

    // A and B src indices get incremented at the same time.
    // Main loop: consumes 8 elements along K per iteration (8 columns of A, 8 rows of B)
    int i = 0;
    for(; i <= ((int)K - 8); i += 8)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + zin.s0));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float8 a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0));
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (one float2 per row of B, advancing one row after each load)
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b1 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b2 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b3 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b4 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b5 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b6 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        float2 b7 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc0.s0 = fma(a0.s0, b0.s0, acc0.s0);
        acc0.s0 = fma(a0.s1, b1.s0, acc0.s0);
        acc0.s0 = fma(a0.s2, b2.s0, acc0.s0);
        acc0.s0 = fma(a0.s3, b3.s0, acc0.s0);
        acc0.s0 = fma(a0.s4, b4.s0, acc0.s0);
        acc0.s0 = fma(a0.s5, b5.s0, acc0.s0);
        acc0.s0 = fma(a0.s6, b6.s0, acc0.s0);
        acc0.s0 = fma(a0.s7, b7.s0, acc0.s0);

        acc0.s1 = fma(a0.s0, b0.s1, acc0.s1);
        acc0.s1 = fma(a0.s1, b1.s1, acc0.s1);
        acc0.s1 = fma(a0.s2, b2.s1, acc0.s1);
        acc0.s1 = fma(a0.s3, b3.s1, acc0.s1);
        acc0.s1 = fma(a0.s4, b4.s1, acc0.s1);
        acc0.s1 = fma(a0.s5, b5.s1, acc0.s1);
        acc0.s1 = fma(a0.s6, b6.s1, acc0.s1);
        acc0.s1 = fma(a0.s7, b7.s1, acc0.s1);

#if M0 > 1
        // Row 1 of the tile: a0 is reused to hold the next row of matrix A
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc1.s0 = fma(a0.s0, b0.s0, acc1.s0);
        acc1.s0 = fma(a0.s1, b1.s0, acc1.s0);
        acc1.s0 = fma(a0.s2, b2.s0, acc1.s0);
        acc1.s0 = fma(a0.s3, b3.s0, acc1.s0);
        acc1.s0 = fma(a0.s4, b4.s0, acc1.s0);
        acc1.s0 = fma(a0.s5, b5.s0, acc1.s0);
        acc1.s0 = fma(a0.s6, b6.s0, acc1.s0);
        acc1.s0 = fma(a0.s7, b7.s0, acc1.s0);

        acc1.s1 = fma(a0.s0, b0.s1, acc1.s1);
        acc1.s1 = fma(a0.s1, b1.s1, acc1.s1);
        acc1.s1 = fma(a0.s2, b2.s1, acc1.s1);
        acc1.s1 = fma(a0.s3, b3.s1, acc1.s1);
        acc1.s1 = fma(a0.s4, b4.s1, acc1.s1);
        acc1.s1 = fma(a0.s5, b5.s1, acc1.s1);
        acc1.s1 = fma(a0.s6, b6.s1, acc1.s1);
        acc1.s1 = fma(a0.s7, b7.s1, acc1.s1);
#endif // M0 > 1
#if M0 > 2
        // Row 2 of the tile
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc2.s0 = fma(a0.s0, b0.s0, acc2.s0);
        acc2.s0 = fma(a0.s1, b1.s0, acc2.s0);
        acc2.s0 = fma(a0.s2, b2.s0, acc2.s0);
        acc2.s0 = fma(a0.s3, b3.s0, acc2.s0);
        acc2.s0 = fma(a0.s4, b4.s0, acc2.s0);
        acc2.s0 = fma(a0.s5, b5.s0, acc2.s0);
        acc2.s0 = fma(a0.s6, b6.s0, acc2.s0);
        acc2.s0 = fma(a0.s7, b7.s0, acc2.s0);

        acc2.s1 = fma(a0.s0, b0.s1, acc2.s1);
        acc2.s1 = fma(a0.s1, b1.s1, acc2.s1);
        acc2.s1 = fma(a0.s2, b2.s1, acc2.s1);
        acc2.s1 = fma(a0.s3, b3.s1, acc2.s1);
        acc2.s1 = fma(a0.s4, b4.s1, acc2.s1);
        acc2.s1 = fma(a0.s5, b5.s1, acc2.s1);
        acc2.s1 = fma(a0.s6, b6.s1, acc2.s1);
        acc2.s1 = fma(a0.s7, b7.s1, acc2.s1);
#endif // M0 > 2
#if M0 > 3
        // Row 3 of the tile
#if defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#else  // defined(REINTERPRET_INPUT_AS_3D)
        a0 = vload8(0, (__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // defined(REINTERPRET_INPUT_AS_3D)
        acc3.s0 = fma(a0.s0, b0.s0, acc3.s0);
        acc3.s0 = fma(a0.s1, b1.s0, acc3.s0);
        acc3.s0 = fma(a0.s2, b2.s0, acc3.s0);
        acc3.s0 = fma(a0.s3, b3.s0, acc3.s0);
        acc3.s0 = fma(a0.s4, b4.s0, acc3.s0);
        acc3.s0 = fma(a0.s5, b5.s0, acc3.s0);
        acc3.s0 = fma(a0.s6, b6.s0, acc3.s0);
        acc3.s0 = fma(a0.s7, b7.s0, acc3.s0);

        acc3.s1 = fma(a0.s0, b0.s1, acc3.s1);
        acc3.s1 = fma(a0.s1, b1.s1, acc3.s1);
        acc3.s1 = fma(a0.s2, b2.s1, acc3.s1);
        acc3.s1 = fma(a0.s3, b3.s1, acc3.s1);
        acc3.s1 = fma(a0.s4, b4.s1, acc3.s1);
        acc3.s1 = fma(a0.s5, b5.s1, acc3.s1);
        acc3.s1 = fma(a0.s6, b6.s1, acc3.s1);
        acc3.s1 = fma(a0.s7, b7.s1, acc3.s1);
#endif // M0 > 3

        // Advance matrix A by the 8 floats consumed in this iteration
        src_addr.s0 += sizeof(float) * 8;
    }
    // Tail loop: handles the remaining (K % 8) elements one float at a time
    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        float a0 = *((__global float *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        float a1 = *((__global float *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        float a2 = *((__global float *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        float a3 = *((__global float *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        float2 b0 = vload2(0, (__global float *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Multiply and accumulate
        acc0.s0 = fma(a0, b0.s0, acc0.s0);
        acc0.s1 = fma(a0, b0.s1, acc0.s1);
#if M0 > 1
        acc1.s0 = fma(a1, b0.s0, acc1.s0);
        acc1.s1 = fma(a1, b0.s1, acc1.s1);
#endif // M0 > 1
#if M0 > 2
        acc2.s0 = fma(a2, b0.s0, acc2.s0);
        acc2.s1 = fma(a2, b0.s1, acc2.s1);
#endif // M0 > 2
#if M0 > 3
        acc3.s0 = fma(a3, b0.s0, acc3.s0);
        acc3.s1 = fma(a3, b0.s1, acc3.s1);
#endif // M0 > 3

        src_addr.s0 += sizeof(float);
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float))
                               + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    // Per-row extra byte offset for the store; stays 0 unless the output is reinterpreted as 3D
    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product
#if defined(ALPHA)
    SCALE_BLOCK(M0, float, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // A single bias row is loaded and added to every row of the tile
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float));

    LOAD_BLOCK(1, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    // The bias has its own M0 x 2 tile, addressed with the same start row and batch index as the destination
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)2 * sizeof(float))
                                + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 2, float, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, float, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, float, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    // cond_y: this work-item handles the first block of rows
    // cond_x: the 2 columns of this work-item reach or go past column N
    // NOTE(review): presumably these select the PARTIAL_STORE_M0 / PARTIAL_STORE_N0 sized store inside
    // STORE_BLOCK_BOUNDARY_AWARE - confirm against the macro definition in gemm_helpers.h
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 2 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 2, float, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
2514
2515#if defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
2516/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
2517 *
2518 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and accumulating the result in a 32 floating point variable.
2519 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
SiCong Li0ea50e32020-11-05 09:18:11 +00002520 * @note This kernel processes a fixed number of elements along x: -DN0=8.
2521 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
2522 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
2523 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
2524 * @note The optional alpha's value need to be passed at compile time using -DALPHA
SiCong Li4abc9d12020-10-28 14:19:28 +00002525 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
2526 * This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
2527 *
2528 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
2529 * The activation function is performed after the bias addition
2530 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
2531 * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
2532 * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
2533 * -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
2534 * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
2535 * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
2536 *
2537 * @param[in] src0_ptr Pointer to the source matrix. Supported data types: F16
2538 * @param[in] src0_stride_x Stride of the source matrix in X dimension (in bytes)
2539 * @param[in] src0_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2540 * @param[in] src0_stride_y Stride of the source matrix in Y dimension (in bytes)
2541 * @param[in] src0_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2542 * @param[in] src0_offset_first_element_in_bytes The offset of the first element in the source matrix
2543 * @param[in] src1_ptr Pointer to the source matrix. Supported data types: same as @p src0_ptr
2544 * @param[in] src1_stride_x Stride of the source matrix in X dimension (in bytes)
2545 * @param[in] src1_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2546 * @param[in] src1_stride_y Stride of the source matrix in Y dimension (in bytes)
2547 * @param[in] src1_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2548 * @param[in] src1_offset_first_element_in_bytes The offset of the first element in the source matrix
2549 * @param[in] src2_ptr (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
2550 * @param[in] src2_stride_x (Optional) Stride of the bias matrix in X dimension (in bytes)
2551 * @param[in] src2_step_x (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
2552 * @param[in] src2_stride_y (Optional) Stride of the bias matrix in Y dimension (in bytes)
2553 * @param[in] src2_step_y (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
2554 * @param[in] src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
2555 * @param[out] dst_ptr Pointer to the destination matrix Supported data types: same as @p src0_ptr
2556 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
2557 * @param[in] dst_step_x dst_gx_stride_x * number of elements along X processed per workitem(in bytes)
2558 * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
2559 * @param[in] dst_step_y dst_gx_stride_y * number of elements along Y processed per workitem(in bytes)
2560 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination matrix
2561 * @param[in] src0_stride_z Stride of the source matrix in Z dimension (in bytes)
2562 * @param[in] src1_stride_z Stride of the source matrix in Z dimension (in bytes)
2563 * @param[in] src2_stride_z (Optional) Stride of the bias matrix in Z dimension (in bytes)
2564 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2565 * @param[in] src_cross_plane_pad (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
2566 * @param[in] dst_cross_plane_pad (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
2567 */
2568__kernel void gemm_mm_floating_point_f16_bifrost_acc32(IMAGE_DECLARATION(src0),
2569 IMAGE_DECLARATION(src1),
2570#if defined(BETA)
2571 IMAGE_DECLARATION(src2),
2572#endif // defined(BETA)
2573 IMAGE_DECLARATION(dst),
2574 uint src0_stride_z,
2575 uint src1_stride_z,
2576#if defined(BETA)
2577 uint src2_stride_z,
2578#endif //defined(BETA)
2579 uint dst_stride_z
2580#if defined(REINTERPRET_INPUT_AS_3D)
2581 ,
2582 uint src_cross_plane_pad
2583#endif // REINTERPRET_INPUT_AS_3D
2584#if defined(REINTERPRET_OUTPUT_AS_3D)
2585 ,
2586 uint dst_cross_plane_pad
2587#endif // REINTERPRET_OUTPUT_AS_3D
2588 )
2589{
2590 int idx = get_global_id(0) * N0;
2591
2592 // Compute starting address for matrix A and Matrix B
2593 int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));
2594
2595 // Update address for the matrix A
SiCong Li0ea50e32020-11-05 09:18:11 +00002596 src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;
SiCong Li4abc9d12020-10-28 14:19:28 +00002597
2598 // Update address for the matrix B
2599 src_addr.s1 += idx * sizeof(half);
2600
2601#if defined(REINTERPRET_INPUT_AS_3D)
2602 // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
2603 // in order to take into account the presence of possible cross plane paddings
2604 //
2605 // | |
2606 // | plane0 |
2607 // | |
2608 // |__________________|
2609 // |******************|
2610 // | cross_plane_pad |
2611 // |******************|
2612 // | |
2613 // | plane1 |
2614 // | |
2615 // |__________________|
2616
SiCong Li0ea50e32020-11-05 09:18:11 +00002617 // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
2618 uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002619 zin = min(DEPTH_GEMM3D - 1, zin);
2620
2621 // Add offset due to the cross plane paddings
2622 zin *= (src_cross_plane_pad * src0_stride_y);
2623
2624 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2625 // multiply src0_stride_z by DEPTH_GEMM3D
2626 src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;
2627
2628#else // defined(REINTERPRET_INPUT_AS_3D)
2629
2630 // Add offset for batched GEMM
2631 src_addr.s0 += get_global_id(2) * src0_stride_z;
2632
2633#endif // defined(REINTERPRET_INPUT_AS_3D)
2634
2635#if defined(MATRIX_B_DEPTH)
2636 // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
2637 src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
2638#else // defined(MATRIX_B_DEPTH)
2639 src_addr.s1 += get_global_id(2) * src1_stride_z;
2640#endif // defined(MATRIX_B_DEPTH)
2641
2642 float8 acc0 = 0.0h;
2643#if M0 > 1
2644 float8 acc1 = 0.0h;
2645#endif // M0 > 1
2646#if M0 > 2
2647 float8 acc2 = 0.0h;
2648#endif // M0 > 2
2649#if M0 > 3
2650 float8 acc3 = 0.0h;
2651#endif // M0 > 3
2652
2653 int i = 0;
2654 for(; i <= ((int)K - 4); i += 4)
2655 {
2656#if defined(REINTERPRET_INPUT_AS_3D)
2657 // Load values from matrix A
2658 LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
2659#else // defined(REINTERPRET_INPUT_AS_3D)
2660 // Load values from matrix A
2661 half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2662#if M0 > 1
2663 half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2664#endif // M0 > 1
2665#if M0 > 2
2666 half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2667#endif // M0 > 2
2668#if M0 > 3
2669 half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2670#endif // M0 > 3
2671#endif // defined(REINTERPRET_INPUT_AS_3D)
2672
2673 // Load values from matrix B
2674 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2675 src_addr.s1 += src1_stride_y;
2676
2677 // Accumulate
2678 acc0 = fma(b0, (float8)a0.s0, acc0);
2679#if M0 > 1
2680 acc1 = fma(b0, (float8)a1.s0, acc1);
2681#endif // M0 > 1
2682#if M0 > 2
2683 acc2 = fma(b0, (float8)a2.s0, acc2);
2684#endif // M0 > 2
2685#if M0 > 3
2686 acc3 = fma(b0, (float8)a3.s0, acc3);
2687#endif // M0 > 3
2688
2689 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2690 src_addr.s1 += src1_stride_y;
2691 acc0 = fma(b0, (float8)a0.s1, acc0);
2692#if M0 > 1
2693 acc1 = fma(b0, (float8)a1.s1, acc1);
2694#endif // M0 > 1
2695#if M0 > 2
2696 acc2 = fma(b0, (float8)a2.s1, acc2);
2697#endif // M0 > 2
2698#if M0 > 3
2699 acc3 = fma(b0, (float8)a3.s1, acc3);
2700#endif // M0 > 3
2701
2702 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2703 src_addr.s1 += src1_stride_y;
2704 acc0 = fma(b0, (float8)a0.s2, acc0);
2705#if M0 > 1
2706 acc1 = fma(b0, (float8)a1.s2, acc1);
2707#endif // M0 > 1
2708#if M0 > 2
2709 acc2 = fma(b0, (float8)a2.s2, acc2);
2710#endif // M0 > 2
2711#if M0 > 3
2712 acc3 = fma(b0, (float8)a3.s2, acc3);
2713#endif // M0 > 3
2714
2715 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2716 src_addr.s1 += src1_stride_y;
2717 acc0 = fma(b0, (float8)a0.s3, acc0);
2718#if M0 > 1
2719 acc1 = fma(b0, (float8)a1.s3, acc1);
2720#endif // M0 > 1
2721#if M0 > 2
2722 acc2 = fma(b0, (float8)a2.s3, acc2);
2723#endif // M0 > 2
2724#if M0 > 3
2725 acc3 = fma(b0, (float8)a3.s3, acc3);
2726#endif // M0 > 3
2727
2728 src_addr.s0 += 4 * sizeof(half);
2729 }
2730
2731 for(; i < (int)K; ++i)
2732 {
2733#if defined(REINTERPRET_INPUT_AS_3D)
2734 // Load values from matrix A
2735 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
2736#if M0 > 1
2737 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
2738#endif // M0 > 1
2739#if M0 > 2
2740 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
2741#endif // M0 > 2
2742#if M0 > 3
2743 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
2744#endif // M0 > 3
2745#else // defined(REINTERPRET_INPUT_AS_3D)
2746 // Load values from matrix A
2747 half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
2748#if M0 > 1
2749 half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
2750#endif // M0 > 1
2751#if M0 > 2
2752 half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
2753#endif // M0 > 2
2754#if M0 > 3
2755 half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
2756#endif // M0 > 3
2757#endif // defined(REINTERPRET_INPUT_AS_3D)
2758
2759 // Load values from matrix B
2760 float8 b0 = convert_float8(vload8(0, (__global half *)(src1_ptr + src_addr.s1)));
2761
2762 src_addr += (int2)(sizeof(half), src1_stride_y);
2763
2764 // Accumulate
2765 acc0 = fma(b0, (float8)a0, acc0); // b0 * (half8)a0;
2766#if M0 > 1
2767 acc1 = fma(b0, (float8)a1, acc1); // b0 * (half8)a1;
2768#endif // M0 > 1
2769#if M0 > 2
2770 acc2 = fma(b0, (float8)a2, acc2); // b0 * (half8)a2;
2771#endif // M0 > 2
2772#if M0 > 3
2773 acc3 = fma(b0, (float8)a3, acc3); // b0 * (half8)a3;
2774#endif // M0 > 3
2775 }
2776
2777 int z = get_global_id(2);
2778
SiCong Li4abc9d12020-10-28 14:19:28 +00002779 // Compute dst address
SiCong Li0ea50e32020-11-05 09:18:11 +00002780 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);
SiCong Li4abc9d12020-10-28 14:19:28 +00002781
2782 uint4 zout = 0;
2783
2784#if defined(REINTERPRET_OUTPUT_AS_3D)
2785
2786 // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
2787 // in order to take into account the presence of possible cross plane paddings
2788 //
2789 // | |
2790 // | plane0 |
2791 // | |
2792 // |__________________|
2793 // |******************|
2794 // | cross_plane_pad |
2795 // |******************|
2796 // | |
2797 // | plane1 |
2798 // | |
2799 // |__________________|
2800
SiCong Li0ea50e32020-11-05 09:18:11 +00002801 // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
2802 zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
SiCong Li4abc9d12020-10-28 14:19:28 +00002803 zout = min(DEPTH_GEMM3D - 1, zout);
2804
2805 // Add offset due to the cross plane paddings
2806 zout *= (dst_cross_plane_pad * dst_stride_y);
2807
2808 // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
2809 // multiply dst_stride_z by DEPTH_GEMM3D
2810 dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
2811#else // defined(REINTERPRET_OUTPUT_AS_3D)
2812 // Add offset for batched GEMM
2813 dst_addr += z * dst_stride_z;
2814#endif // defined(REINTERPRET_OUTPUT_AS_3D)
2815
2816 // Multiply by the weight of matrix-matrix product and store the result
2817#if defined(ALPHA)
2818 SCALE_BLOCK(M0, float, acc, ALPHA);
2819#endif // defined(ALPHA)
2820
2821#if defined(BETA)
2822 REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);
2823
2824#if defined(BROADCAST_BIAS)
2825 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));
2826
2827 LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
2828
2829 float8 bias_f0 = convert_float8(bias0);
2830
2831#ifndef UNIT_BETA
2832 SCALE_BLOCK(1, float, bias_f, BETA);
2833#endif // UNIT_BIAS
2834
2835 // acc = acc + bias[broadcasted]
2836 ADD_BLOCK_BROADCAST(M0, acc, bias_f0);
2837
2838#else // defined(BROADCAST_BIAS)
SiCong Li0ea50e32020-11-05 09:18:11 +00002839 __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0,
2840 PARTIAL_STORE_M0)
2841 * src2_stride_y)
2842 + z * src2_stride_z;
SiCong Li4abc9d12020-10-28 14:19:28 +00002843
2844 LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);
2845
2846 float8 bias_f0 = convert_float8(bias0);
2847#if M0 > 1
2848 float8 bias_f1 = convert_float8(bias1);
2849#endif // M0 > 1
2850#if M0 > 2
2851 float8 bias_f2 = convert_float8(bias2);
2852#endif // M0 > 2
2853#if M0 > 3
2854 float8 bias_f3 = convert_float8(bias3);
2855#endif // M0 > 3
2856
2857#ifndef UNIT_BETA
2858 SCALE_BLOCK(M0, float, bias_f, BETA);
2859#endif // UNIT_BIAS
2860
2861 // acc = acc + bias
2862 ADD_BLOCK(M0, acc, bias_f);
2863
2864#endif // defined(BROADCAST_BIAS)
2865#endif // defined(BETA)
2866
2867 half8 acc_h0 = convert_half8(acc0);
2868#if M0 > 1
2869 half8 acc_h1 = convert_half8(acc1);
2870#endif // M0 > 1
2871#if M0 > 2
2872 half8 acc_h2 = convert_half8(acc2);
2873#endif // M0 > 2
2874#if M0 > 3
2875 half8 acc_h3 = convert_half8(acc3);
2876#endif // M0 > 3
2877
2878#if defined(ACTIVATION_TYPE)
2879 ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc_h, A_VAL, B_VAL);
2880#endif // defined(ACTIVATION_TYPE)
2881
2882 // Store the output block
SiCong Li0ea50e32020-11-05 09:18:11 +00002883 const bool cond_y = get_global_id(1) == 0;
2884 const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
2885 STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc_h, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
SiCong Li4abc9d12020-10-28 14:19:28 +00002886}
2887
/** This OpenCL kernel computes the matrix by matrix multiplication between the matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
 *
 * @note This OpenCL kernel works with the 16-bit floating point data type (half) and uses the fma units.
 *       Accumulation is performed in half precision.
 * @note The number of elements processed along the x and y directions must be passed at compile time using -DN0 and -DM0.
 * @note This kernel processes a fixed number of elements along x: -DN0=8.
 * @note The number of columns of matrix A and the number of columns of the matrix B need to be passed at compile time using -DK and -DN
 * @note The size of the partial store block in y must be passed at compile time using -DPARTIAL_STORE_M0 (e.g. -DPARTIAL_STORE_M0=1)
 * @note The size of the partial store block in x must be passed at compile time using -DPARTIAL_STORE_N0 (e.g. -DPARTIAL_STORE_N0=1)
 * @note The optional alpha's value need to be passed at compile time using -DALPHA
 * @note The optional beta's value need to be passed at compile time using -DBETA. The bias arguments (src2) are part of the kernel signature only when BETA is defined.
 *       If -DUNIT_BETA is passed, the bias is added without being scaled by BETA.
 *       If -DBROADCAST_BIAS is passed, a single row of bias is loaded and added to every row of the output block.
 * @note In case the matrix B has 3 dimensions and the matrix A more than 3, in order to avoid out-of-bounds reads, the number of channels of matrix B must be passed at compile time using MATRIX_B_DEPTH (e.g. -DMATRIX_B_DEPTH=16)
 *       This case can happen when GEMM is used to perform the element-wise multiplication through a batched matrix multiplication (2D Winograd) and we have multiple inputs (e.g. a = [K, M, 16, Batches], b = [N, K, 16])
 *
 * @note If the activation type were passed at compile time through -DACTIVATION_TYPE (e.g. -DACTIVATION_TYPE=RELU), A, B variables, required by some activation functions, should be passed at compile time as well using -DA_VAL= and -DB_VAL= respectively.
 *       The activation function is performed after the bias addition
 * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
 *       -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
 *       -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
 *       -# HEIGHT_GEMM3D: The height of the output in case it has to be reinterpreted as a 3D tensor.
 *       -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
 *          (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns matrix A NOT reshaped
 *
 * @param[in]  src0_ptr                           Pointer to the source matrix. Supported data types: F16
 * @param[in]  src0_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src0_step_x                        src0_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src0_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src0_step_y                        src0_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src0_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src1_ptr                           Pointer to the source matrix. Supported data types: same as @p src0_ptr
 * @param[in]  src1_stride_x                      Stride of the source matrix in X dimension (in bytes)
 * @param[in]  src1_step_x                        src1_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src1_stride_y                      Stride of the source matrix in Y dimension (in bytes)
 * @param[in]  src1_step_y                        src1_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src1_offset_first_element_in_bytes The offset of the first element in the source matrix
 * @param[in]  src2_ptr                           (Optional) Pointer to the bias matrix. Supported data type: same as @p src0_ptr
 * @param[in]  src2_stride_x                      (Optional) Stride of the bias matrix in X dimension (in bytes)
 * @param[in]  src2_step_x                        (Optional) src2_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src2_stride_y                      (Optional) Stride of the bias matrix in Y dimension (in bytes)
 * @param[in]  src2_step_y                        (Optional) src2_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src2_offset_first_element_in_bytes (Optional) The offset of the first element in the bias matrix
 * @param[out] dst_ptr                            Pointer to the destination matrix Supported data types: same as @p src0_ptr
 * @param[in]  dst_stride_x                       Stride of the destination matrix in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination matrix in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination matrix
 * @param[in]  src0_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src1_stride_z                      Stride of the source matrix in Z dimension (in bytes)
 * @param[in]  src2_stride_z                      (Optional) Stride of the bias matrix in Z dimension (in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  src_cross_plane_pad                (Optional) Bottom paddings in unit of elements for the input tensor (only if defined REINTERPRET_INPUT_AS_3D)
 * @param[in]  dst_cross_plane_pad                (Optional) Bottom paddings in unit of elements (only if defined REINTERPRET_OUTPUT_AS_3D)
 */
__kernel void gemm_mm_floating_point_f16_bifrost(IMAGE_DECLARATION(src0),
                                                 IMAGE_DECLARATION(src1),
#if defined(BETA)
                                                 IMAGE_DECLARATION(src2),
#endif // defined(BETA)
                                                 IMAGE_DECLARATION(dst),
                                                 uint src0_stride_z,
                                                 uint src1_stride_z,
#if defined(BETA)
                                                 uint src2_stride_z,
#endif // defined(BETA)
                                                 uint dst_stride_z
#if defined(REINTERPRET_INPUT_AS_3D)
                                                 ,
                                                 uint src_cross_plane_pad
#endif // REINTERPRET_INPUT_AS_3D
#if defined(REINTERPRET_OUTPUT_AS_3D)
                                                 ,
                                                 uint dst_cross_plane_pad
#endif // REINTERPRET_OUTPUT_AS_3D
                                                )
{
    // Column offset (in elements) into matrix B for this work-item
    int idx = get_global_id(0) * N0;

    // Compute starting address for matrix A and Matrix B
    // (.s0 is the byte offset into src0 / matrix A, .s1 the byte offset into src1 / matrix B)
    int2 src_addr = ((int2)(src0_offset_first_element_in_bytes, src1_offset_first_element_in_bytes));

    // Update address for the matrix A
    src_addr.s0 += COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src0_stride_y;

    // Update address for the matrix B
    src_addr.s1 += idx * sizeof(half);

#if defined(REINTERPRET_INPUT_AS_3D)
    // Since we load a 2D input tile from a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zin) is calculated dividing row by HEIGHT_GEMM3D
    uint4 zin = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zin       = min(DEPTH_GEMM3D - 1, zin);

    // Add offset due to the cross plane paddings
    zin *= (src_cross_plane_pad * src0_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply src0_stride_z by DEPTH_GEMM3D
    src_addr.s0 += get_global_id(2) * src0_stride_z * DEPTH_GEMM3D;

#else // defined(REINTERPRET_INPUT_AS_3D)

    // Add offset for batched GEMM
    src_addr.s0 += get_global_id(2) * src0_stride_z;

#endif // defined(REINTERPRET_INPUT_AS_3D)

#if defined(MATRIX_B_DEPTH)
    // Do not slide matrix B if the matrix B has 3 dimensions and matrix A more than 3
    src_addr.s1 += (get_global_id(2) % MATRIX_B_DEPTH) * src1_stride_z;
#else  // defined(MATRIX_B_DEPTH)
    src_addr.s1 += get_global_id(2) * src1_stride_z;
#endif // defined(MATRIX_B_DEPTH)

    // One half8 accumulator per output row handled by this work-item (up to 4 rows)
    half8 acc0 = 0.0h;
#if M0 > 1
    half8 acc1 = 0.0h;
#endif // M0 > 1
#if M0 > 2
    half8 acc2 = 0.0h;
#endif // M0 > 2
#if M0 > 3
    half8 acc3 = 0.0h;
#endif // M0 > 3

    // Main loop: consume 4 values of K per iteration
    int i = 0;
    for(; i <= ((int)K - 4); i += 4)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        LOAD_BLOCK(M0, 4, half, a, src0_ptr, src_addr.s0, src0_stride_y, zin.s);
#else // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half4 a0 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half4 a1 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half4 a2 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half4 a3 = vload4(0, (__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B (row k + 0)
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;

        // Accumulate
        acc0 = fma(b0, (half8)a0.s0, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s0, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s0, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s0, acc3);
#endif // M0 > 3

        // Row k + 1 of matrix B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s1, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s1, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s1, acc3);
#endif // M0 > 3

        // Row k + 2 of matrix B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s2, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s2, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s2, acc3);
#endif // M0 > 3

        // Row k + 3 of matrix B
        b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));
        src_addr.s1 += src1_stride_y;
        acc0 = fma(b0, (half8)a0.s3, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1.s3, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2.s3, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3.s3, acc3);
#endif // M0 > 3

        // Advance matrix A by the 4 elements just consumed
        src_addr.s0 += 4 * sizeof(half);
    }

    // Left-over loop: process the remaining values of K (fewer than 4) one at a time
    for(; i < (int)K; ++i)
    {
#if defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y + zin.s0));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y + zin.s1));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y + zin.s2));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y + zin.s3));
#endif // M0 > 3
#else  // defined(REINTERPRET_INPUT_AS_3D)
        // Load values from matrix A
        half a0 = *((__global half *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if M0 > 1
        half a1 = *((__global half *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // M0 > 1
#if M0 > 2
        half a2 = *((__global half *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // M0 > 2
#if M0 > 3
        half a3 = *((__global half *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // M0 > 3
#endif // defined(REINTERPRET_INPUT_AS_3D)

        // Load values from matrix B
        half8 b0 = vload8(0, (__global half *)(src1_ptr + src_addr.s1));

        // Advance matrix A by one element and matrix B by one row
        src_addr += (int2)(sizeof(half), src1_stride_y);

        // Accumulate
        acc0 = fma(b0, (half8)a0, acc0);
#if M0 > 1
        acc1 = fma(b0, (half8)a1, acc1);
#endif // M0 > 1
#if M0 > 2
        acc2 = fma(b0, (half8)a2, acc2);
#endif // M0 > 2
#if M0 > 3
        acc3 = fma(b0, (half8)a3, acc3);
#endif // M0 > 3
    }

    int z = get_global_id(2);

    // Compute dst address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half)) + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * dst_stride_y);

    uint4 zout = 0;

#if defined(REINTERPRET_OUTPUT_AS_3D)

    // Since we store a 2D output tile in a 3D tensor, we need to check when the plane changes across the z dimension
    // in order to take into account the presence of possible cross plane paddings
    //
    //  |                  |
    //  |      plane0      |
    //  |                  |
    //  |__________________|
    //  |******************|
    //  |  cross_plane_pad |
    //  |******************|
    //  |                  |
    //  |      plane1      |
    //  |                  |
    //  |__________________|

    // The plane (zout) is calculated dividing row by HEIGHT_GEMM3D
    zout = ((uint4)(0, 1, 2, 3) + (uint4)(COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0))) / (uint4)HEIGHT_GEMM3D;
    zout = min(DEPTH_GEMM3D - 1, zout);

    // Add offset due to the cross plane paddings
    zout *= (dst_cross_plane_pad * dst_stride_y);

    // Add offset for batched GEMM. The batches will be in the fourth dimension and for this reason we
    // multiply dst_stride_z by DEPTH_GEMM3D
    dst_addr += z * dst_stride_z * DEPTH_GEMM3D;
#else  // defined(REINTERPRET_OUTPUT_AS_3D)
    // Add offset for batched GEMM
    dst_addr += z * dst_stride_z;
#endif // defined(REINTERPRET_OUTPUT_AS_3D)

    // Multiply by the weight of matrix-matrix product and store the result
#if defined(ALPHA)
    SCALE_BLOCK(M0, half, acc, ALPHA);
#endif // defined(ALPHA)

    // Add beta*bias
#if defined(BETA)
    REPEAT_VAR_INIT_TO_CONST(M0, uint, zero, 0);

#if defined(BROADCAST_BIAS)
    // Only the x offset is applied: a single bias row is loaded and shared by all rows of the block
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half));

    LOAD_BLOCK(1, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(1, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias[broadcasted]
    ADD_BLOCK_BROADCAST(M0, acc, bias0);

#else // defined(BROADCAST_BIAS)
    __global uchar *src2_addr = src2_ptr + src2_offset_first_element_in_bytes + (get_global_id(0) * (uint)8 * sizeof(half))
                                + (COMPUTE_M0_START_ROW(get_global_id(1), M0, PARTIAL_STORE_M0) * src2_stride_y)
                                + z * src2_stride_z;

    LOAD_BLOCK(M0, 8, half, bias, src2_addr, 0, src2_stride_y, zero);

#ifndef UNIT_BETA
    SCALE_BLOCK(M0, half, bias, BETA);
#endif // UNIT_BETA

    // acc = acc + bias
    ADD_BLOCK(M0, acc, bias);

#endif // defined(BROADCAST_BIAS)
#endif // defined(BETA)

#if defined(ACTIVATION_TYPE)
    ACTIVATION_BLOCK(M0, ACTIVATION_TYPE, half, VEC_SIZE, acc, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)

    // Store the output block
    // cond_y is true for the first work-item along y; cond_x is true when this work-item's
    // 8-column block reaches or passes column N. Both are forwarded to the boundary-aware store
    // together with PARTIAL_STORE_M0 / PARTIAL_STORE_N0.
    const bool cond_y = get_global_id(1) == 0;
    const bool cond_x = ((get_global_id(0) + 1) * 8 >= N);
    STORE_BLOCK_BOUNDARY_AWARE(M0, 8, half, acc, dst_addr, dst_stride_y, zout.s, PARTIAL_STORE_M0, PARTIAL_STORE_N0, cond_y, cond_x);
}
3236#endif // defined(ARM_COMPUTE_OPENCL_FP16_ENABLED)
3237
SiCong Li0ea50e32020-11-05 09:18:11 +00003238#endif // defined(N) && defined(K) && defined(M0) && defined(N0) && defined(PARTIAL_STORE_M0) && defined(PARTIAL_STORE_N0)