1/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
27/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2/2x1 or 1x2, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
28 *
29 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
30 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
31 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
32 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
33 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
34 *
35 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
36 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
37 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
38 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
39 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
40 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
41 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
42 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
43 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
44 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
45 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
46 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
47 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 48 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 49 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
50 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
51 */
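// Illustrative build options (example values only; the host side derives the real values from the tensor shapes):
//   2x2 output tile, 3x3 filter: -DNUM_TILES_X=16 -DOUTPUT_TILE_W=2 -DOUTPUT_TILE_H=2
//   2x1 output tile, 3x1 filter: -DNUM_TILES_X=16 -DOUTPUT_TILE_W=2 -DOUTPUT_TILE_H=1 -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL
//   1x2 output tile, 1x3 filter: -DNUM_TILES_X=16 -DOUTPUT_TILE_W=1 -DOUTPUT_TILE_H=2 -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL
// Append -DHAS_BIAS when a bias vector is passed to the kernel.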
52__kernel void winograd_output_transform_2x2_3x3_nchw(
53 TENSOR3D_DECLARATION(src),
54 TENSOR3D_DECLARATION(dst)
55#if defined(HAS_BIAS)
56 ,
57 VECTOR_DECLARATION(bias)
58#endif // defined(HAS_BIAS)
59)
60{
 61 // Each thread stores a 2x2/2x1 or 1x2 tile according to the filter size
62 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
63
64 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
65
 66 // Load the values across the 16 or 4 channels to compose the 4x4, 4x1 or 1x4 tile
67 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
68 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
69 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
70 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
71
72#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
73 // Compute the 2x1 or 1x2 output tile
74 // out00 = d00 + d01 + d02
75 // out01 = d01 - d02 - d03
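    // These are the two rows of the 1D output transform matrix A^T = [ 1  1  1  0 ; 0  1 -1 -1 ]
    // of F(2, 3), applied to the four GEMM results d00..d03 loaded for this tile.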
76
77 float out00 = d00 + d01 + d02;
78 float out01 = d01 - d02 - d03;
79#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
80 float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
81 float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
82 float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
83 float d13 = *((__global float *)(src_addr + 7 * src_stride_z));
84
85 float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
86 float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
87 float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
88 float d23 = *((__global float *)(src_addr + 11 * src_stride_z));
89
90 float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
91 float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
92 float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
93 float d33 = *((__global float *)(src_addr + 15 * src_stride_z));
94
95 // Compute the 2x2 output tile
96 float k0 = d01 + d11 + d21;
97 float k1 = d02 + d12 + d22;
98 float k2 = d11 - d21 - d31;
99 float k3 = d12 - d22 - d32;
100
101 // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
102 // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
103 // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
104 // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
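    // In matrix form this is Y = A^T * M * A with A^T = [ 1  1  1  0 ; 0  1 -1 -1 ], where M is the 4x4
    // tile d00..d33; k0..k3 are elements of the intermediate product A^T * M reused by both output columns.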
105
106 float out00 = d10;
107 float out01 = -d13;
108 float out10 = d10;
109 float out11 = -d13;
110
111 out00 += d00 + d20 + k0 + k1;
112 out01 += k0 - k1 - (d03 + d23);
113 out10 += -d20 - d30 + k2 + k3;
114 out11 += k2 - k3 + d23 + d33;
115#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
116
117 int y_in = get_global_id(1);
118 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
119 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
120 int z_out = get_global_id(0);
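    // get_global_id(1) is the linear tile index, decomposed into the (x_out, y_out) position of the tile's
    // top-left element, while get_global_id(0) selects the output feature map (the Z dimension in NCHW).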
121
122#if defined(HAS_BIAS)
123 // Add bias
124 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
125
126 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
127
128 out00 += (float)b;
129 out01 += (float)b;
130#endif // defined(HAS_BIAS)
131
132 // Get output address
133 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
134
135 // Store the output tile
136#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
137 *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
138 *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
139#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
140 vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
141#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
142
143#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
144#if defined(HAS_BIAS)
145 // Add bias
146 out10 += (float)b;
147 out11 += (float)b;
148#endif // defined(HAS_BIAS)
149
150 vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
151#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
152}
153
 154/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
155 *
156 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
157 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
158 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
159 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
160 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
161 *
162 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
163 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
164 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
165 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
166 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
167 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
168 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
169 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
170 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
171 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
172 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
173 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
174 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 175 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 176 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
177 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
178 */
179__kernel void winograd_output_transform_4x4_3x3_nchw(
180 TENSOR3D_DECLARATION(src),
181 TENSOR3D_DECLARATION(dst)
182#if defined(HAS_BIAS)
183 ,
184 VECTOR_DECLARATION(bias)
185#endif // defined(HAS_BIAS)
186)
187{
188 // Each thread stores a 4x4/4x1 or 1x4 tile
189 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
190
191 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
192
 193 // Load the values across the 36 or 6 channels to compose the 6x6, 6x1 or 1x6 tile
194 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
195 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
196 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
197 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
198 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
199 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
200
201#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
202 // Compute out00, out01, out02 and out03
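    // These are the four rows of the 1D output transform matrix A^T of F(4, 3):
    //   [ 1  1  1  1  1  0 ]
    //   [ 0  1 -1  2 -2  0 ]
    //   [ 0  1  1  4  4  0 ]
    //   [ 0  1 -1  8 -8  1 ]
    // applied to the six GEMM results d00..d05 loaded for this tile.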
203 float out00 = d00 + d01 + d02 + d03 + d04;
204 float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
205 float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
206 float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
207#else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
208 float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
209 float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
210 float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
211 float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
212 float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
213 float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
214
215 float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
216 float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
217 float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
218 float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
219 float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
220 float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
221
222 float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
223 float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
224 float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
225 float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
226 float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
227 float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
228
229 float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
230 float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
231 float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
232 float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
233 float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
234 float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
235
236 float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
237 float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
238 float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
239 float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
240 float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
241 float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
242
243 // Compute out00, out01, out02 and out03
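    // 2D case: Y = A^T * M * A with the same A^T as in the 1D branch above and M the 6x6 tile d00..d55.
    // Each output row is built incrementally: the initial value below is the column-1 term common to all
    // four outputs of the row, while k0/k1 cache further combinations reused across out00..out03.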
244 float out00 = d01 + d21 + d41 + d11 + d31;
245 float out01 = d01 + d21 + d41 + d11 + d31;
246 float out02 = d01 + d21 + d41 + d11 + d31;
247 float out03 = d01 + d21 + d41 + d11 + d31;
248
249 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
250 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
251
252 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
253 out01 += k1 - d02 - d12 - d22 - d32 - d42;
254 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
255 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
256
257 // Compute out10, out11, out12 and out13
258 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
259 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
260 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
261 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
262
263 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
264 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
265
266 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
267 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
268 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
269 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
270
271 // Compute out20, out21, out22 and out23
272 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
273 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
274 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
275 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
276
277 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
278 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
279
280 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
281 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
282 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
283 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
284
285 // Compute out30, out31, out32 and out33
286 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
287 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
288 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
289 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
290
291 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
292 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
293
294 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
295 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
296 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
297 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
298#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
299
300 int y_in = get_global_id(1);
301 int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
302 int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
303 int z_out = get_global_id(0);
304
305#if defined(HAS_BIAS)
306 // Add bias
307 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
308
309 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
310
311 out00 += (float)b;
312 out01 += (float)b;
313 out02 += (float)b;
314 out03 += (float)b;
315#endif // defined(HAS_BIAS)
316
317 // Get output address
318 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;
319
320 // Store the output tile
321#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
322 *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
323 *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
324 *((__global float *)(dst_addr + 2 * dst_stride_y)) = out02;
325 *((__global float *)(dst_addr + 3 * dst_stride_y)) = out03;
326#else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
327 vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
328#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
329
330#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
331#if defined(HAS_BIAS)
332 // Add bias
333 out10 += (float)b;
334 out11 += (float)b;
335 out12 += (float)b;
336 out13 += (float)b;
337
338 out20 += (float)b;
339 out21 += (float)b;
340 out22 += (float)b;
341 out23 += (float)b;
342
343 out30 += (float)b;
344 out31 += (float)b;
345 out32 += (float)b;
346 out33 += (float)b;
347#endif // defined(HAS_BIAS)
348 vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
349 vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
350 vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
351#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
352}
353
354#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
355/** This OpenCL kernel performs Winograd output transform when the output tile is 2x1, the filter size 3x1 and the data layout is NCHW
356 *
357 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
358 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
359 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
360 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
361 *
362 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
363 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
364 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
365 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
366 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
367 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
368 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
369 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
370 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
371 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
372 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
373 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
374 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 375 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 376 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
377 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
378 */
379__kernel void winograd_output_transform_2x1_3x1_nchw(
380 TENSOR3D_DECLARATION(src),
381 TENSOR3D_DECLARATION(dst)
382#if defined(HAS_BIAS)
383 ,
384 VECTOR_DECLARATION(bias)
385#endif // defined(HAS_BIAS)
386)
387{
388 winograd_output_transform_2x2_3x3_nchw(src_ptr,
389 src_stride_x,
390 src_step_x,
391 src_stride_y,
392 src_step_y,
393 src_stride_z,
394 src_step_z,
395 src_offset_first_element_in_bytes,
396 dst_ptr,
397 dst_stride_x,
398 dst_step_x,
399 dst_stride_y,
400 dst_step_y,
401 dst_stride_z,
402 dst_step_z,
403 dst_offset_first_element_in_bytes
404#if defined(HAS_BIAS)
405 ,
406 bias_ptr,
407 bias_stride_x,
408 bias_step_x,
409 bias_offset_first_element_in_bytes
410#endif // defined(HAS_BIAS)
411 );
412}
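// winograd_output_transform_2x1_3x1_nchw above and the 4x1/1x2/1x4 variants below are thin wrappers: with
// WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL or WINOGRAD_OUTPUT_TRANSFORM_VERTICAL defined, the 2x2_3x3 and
// 4x4_3x3 kernels compile down to the corresponding 1D transform, so the wrappers simply forward all of
// their arguments.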
413
414/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 3x1 and the data layout is NCHW
415 *
416 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
417 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
418 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
419 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
420 *
421 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
422 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
423 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
424 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
425 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
426 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
427 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
428 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
429 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
430 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
431 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
432 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
433 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 434 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 435 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
436 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
437 */
438__kernel void winograd_output_transform_4x1_3x1_nchw(
439 TENSOR3D_DECLARATION(src),
440 TENSOR3D_DECLARATION(dst)
441#if defined(HAS_BIAS)
442 ,
443 VECTOR_DECLARATION(bias)
444#endif // defined(HAS_BIAS)
445)
446{
447 winograd_output_transform_4x4_3x3_nchw(src_ptr,
448 src_stride_x,
449 src_step_x,
450 src_stride_y,
451 src_step_y,
452 src_stride_z,
453 src_step_z,
454 src_offset_first_element_in_bytes,
455 dst_ptr,
456 dst_stride_x,
457 dst_step_x,
458 dst_stride_y,
459 dst_step_y,
460 dst_stride_z,
461 dst_step_z,
462 dst_offset_first_element_in_bytes
463#if defined(HAS_BIAS)
464 ,
465 bias_ptr,
466 bias_stride_x,
467 bias_step_x,
468 bias_offset_first_element_in_bytes
469#endif // defined(HAS_BIAS)
470 );
471}
472#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
473
474#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
475/** This OpenCL kernel performs Winograd output transform when the output tile is 1x2, the filter size 1x3 and the data layout is NCHW
476 *
477 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
478 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
479 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
480 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
481 *
482 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
483 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
484 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
485 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
486 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
487 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
488 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
489 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
490 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
491 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
492 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
493 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
494 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 495 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 496 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
497 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
498 */
499__kernel void winograd_output_transform_1x2_1x3_nchw(
500 TENSOR3D_DECLARATION(src),
501 TENSOR3D_DECLARATION(dst)
502#if defined(HAS_BIAS)
503 ,
504 VECTOR_DECLARATION(bias)
505#endif // defined(HAS_BIAS)
506)
507{
508 winograd_output_transform_2x2_3x3_nchw(src_ptr,
509 src_stride_x,
510 src_step_x,
511 src_stride_y,
512 src_step_y,
513 src_stride_z,
514 src_step_z,
515 src_offset_first_element_in_bytes,
516 dst_ptr,
517 dst_stride_x,
518 dst_step_x,
519 dst_stride_y,
520 dst_step_y,
521 dst_stride_z,
522 dst_step_z,
523 dst_offset_first_element_in_bytes
524#if defined(HAS_BIAS)
525 ,
526 bias_ptr,
527 bias_stride_x,
528 bias_step_x,
529 bias_offset_first_element_in_bytes
530#endif // defined(HAS_BIAS)
531 );
532}
533
534/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NCHW
535 *
536 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
537 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
538 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
539 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
540 *
541 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
542 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
543 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
544 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
545 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
546 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
547 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
548 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
549 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
550 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
551 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
552 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
553 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 554 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 555 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
556 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
557 */
558__kernel void winograd_output_transform_1x4_1x3_nchw(
559 TENSOR3D_DECLARATION(src),
560 TENSOR3D_DECLARATION(dst)
561#if defined(HAS_BIAS)
562 ,
563 VECTOR_DECLARATION(bias)
564#endif // defined(HAS_BIAS)
565)
566{
567 winograd_output_transform_4x4_3x3_nchw(src_ptr,
568 src_stride_x,
569 src_step_x,
570 src_stride_y,
571 src_step_y,
572 src_stride_z,
573 src_step_z,
574 src_offset_first_element_in_bytes,
575 dst_ptr,
576 dst_stride_x,
577 dst_step_x,
578 dst_stride_y,
579 dst_step_y,
580 dst_stride_z,
581 dst_step_z,
582 dst_offset_first_element_in_bytes
583#if defined(HAS_BIAS)
584 ,
585 bias_ptr,
586 bias_stride_x,
587 bias_step_x,
588 bias_offset_first_element_in_bytes
589#endif // defined(HAS_BIAS)
590 );
591}
592#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
593
594/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC
595 *
596 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
597 *
598 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
599 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
600 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
601 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
602 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
603 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
604 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
605 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
606 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
607 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
608 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
609 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
610 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 611 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 612 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
613 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
614 * @param[in] dst_size Size of the destination tensor, minus the last padding
615 */
616__kernel void winograd_output_transform_4x4_3x3_nhwc(
617 TENSOR3D_DECLARATION(src),
618 TENSOR3D_DECLARATION(dst),
619#if defined(HAS_BIAS)
620 VECTOR_DECLARATION(bias),
621#endif // defined(HAS_BIAS)
622 int dst_size)
623{
624 // Each thread stores a 4x4 tile
625 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
626
627 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
628
629 // Load the values across the 36 channels to compose the 6x6 tile
630 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
631 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
632 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
633 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
634 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
635 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
636
637 float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
638 float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
639 float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
640 float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
641 float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
642 float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
643
644 float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
645 float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
646 float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
647 float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
648 float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
649 float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
650
651 float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
652 float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
653 float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
654 float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
655 float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
656 float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
657
658 float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
659 float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
660 float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
661 float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
662 float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
663 float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
664
665 float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
666 float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
667 float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
668 float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
669 float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
670 float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
671
672 // Compute out00, out01, out02 and out03
673 float out00 = d01 + d21 + d41 + d11 + d31;
674 float out01 = d01 + d21 + d41 + d11 + d31;
675 float out02 = d01 + d21 + d41 + d11 + d31;
676 float out03 = d01 + d21 + d41 + d11 + d31;
677
678 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
679 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
680
681 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
682 out01 += k1 - d02 - d12 - d22 - d32 - d42;
683 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
684 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
685
686 // Compute out10, out11, out12 and out13
687 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
688 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
689 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
690 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
691
692 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
693 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
694
695 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
696 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
697 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
698 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
699
700 // Compute out20, out21, out22 and out23
701 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
702 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
703 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
704 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
705
706 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
707 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
708
709 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
710 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
711 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
712 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
713
714 // Compute out30, out31, out32 and out33
715 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
716 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
717 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
718 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
719
720 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
721 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
722
723 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
724 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
725 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
726 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
727
728 int y_in = get_global_id(1);
729 int x_out = get_global_id(0);
730 int y_out = (y_in % NUM_TILES_X) * 4;
731 int z_out = (y_in / NUM_TILES_X) * 4;
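    // In NHWC the X dimension holds the channels, so get_global_id(0) selects the output channel directly,
    // while the linear tile index from get_global_id(1) gives the spatial (Y, Z) position of the 4x4 tile.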
732
733#if defined(HAS_BIAS)
734 // Add bias
735 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
736
 737 float b = (float) * ((__global float *)(vector_offset(&bias, x_out))); // In NHWC the output channel index is x_out
738
739 out00 += (float)b;
740 out01 += (float)b;
741 out02 += (float)b;
742 out03 += (float)b;
743
744 out10 += (float)b;
745 out11 += (float)b;
746 out12 += (float)b;
747 out13 += (float)b;
748
749 out20 += (float)b;
750 out21 += (float)b;
751 out22 += (float)b;
752 out23 += (float)b;
753
754 out30 += (float)b;
755 out31 += (float)b;
756 out32 += (float)b;
757 out33 += (float)b;
758
759#endif // defined(HAS_BIAS)
760
761 // Get output address
762 int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
763 offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
764 int4 mult_y = min(dst_size - offset, 1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
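    // For the last partial tile the clamped offsets collapse onto dst_size and the corresponding mult_y
    // component is 0, so those values are all written to the single padding element at dst_ptr + dst_size
    // instead of past the end of the buffer.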
765
766 // Store the 4x4 output tile
767 *((__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = out00;
768 *((__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = out01;
769 *((__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = out02;
770 *((__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = out03;
771 *((__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = out10;
772 *((__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = out11;
773 *((__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = out12;
774 *((__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = out13;
775 *((__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = out20;
776 *((__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = out21;
777 *((__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = out22;
778 *((__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = out23;
779 *((__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = out30;
780 *((__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = out31;
781 *((__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = out32;
782 *((__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = out33;
783}
784
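/** COMPUTE_TMP_COL applies the 1D output transform of F(4, 5) to one column d0..d7 of the 8x8 GEMM tile,
 * i.e. it multiplies the column by
 *   A^T = [ 1  1  1  1  1  8  8  0 ]
 *         [ 0  1 -1  2 -2  4 -4  0 ]
 *         [ 0  1  1  4  4  2  2  0 ]
 *         [ 0  1 -1  8 -8  1 -1  1 ]
 * reusing comm_fact for the pairwise sums and differences.
 */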
785#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact) \
786 ({ \
787 comm_fact.s0 = d1 + d2; \
788 comm_fact.s1 = d3 + d4; \
789 comm_fact.s2 = d5 + d6; \
790 \
791 col.s0 = comm_fact.s0 + comm_fact.s1 + 8.f * comm_fact.s2 + d0; \
792 col.s2 = comm_fact.s0 + 4.f * comm_fact.s1 + 2.f * comm_fact.s2; \
793 \
794 comm_fact.s0 = d1 - d2; \
795 comm_fact.s1 = d3 - d4; \
796 comm_fact.s2 = d5 - d6; \
797 \
798 col.s1 = comm_fact.s0 + 2.f * comm_fact.s1 + 4.f * comm_fact.s2; \
799 col.s3 = comm_fact.s0 + 8.f * comm_fact.s1 + comm_fact.s2 + d7; \
800 })
801
802/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data layout is NCHW
803 *
804 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
805 *
806 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
807 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
808 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
809 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
810 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
811 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
812 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
813 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
814 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
815 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
816 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
817 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
818 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 819 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 820 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
821 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
822 */
823__kernel void winograd_output_transform_4x4_5x5_nchw(
824 TENSOR3D_DECLARATION(src),
825 TENSOR3D_DECLARATION(dst)
826#if defined(HAS_BIAS)
827 ,
828 VECTOR_DECLARATION(bias)
829#endif // defined(HAS_BIAS)
830)
831{
832 // Each thread stores a 4x4 tile
833 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
834
835 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
836
837 // Load the values across the 64 channels to compose the 8x8 input tile
838 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
839 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
840 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
841 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
842 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
843 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
844 float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
845 float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
846
847 float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
848 float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
849 float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
850 float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
851 float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
852 float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
853 float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
854 float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
855
856 float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
857 float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
858 float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
859 float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
860 float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
861 float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
862 float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
863 float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
864
865 float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
866 float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
867 float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
868 float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
869 float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
870 float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
871 float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
872 float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
873
874 float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
875 float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
876 float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
877 float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
878 float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
879 float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
880 float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
881 float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
882
883 float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
884 float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
885 float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
886 float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
887 float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
888 float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
889 float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
890 float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
891
892 float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
893 float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
894 float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
895 float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
896 float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
897 float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
898 float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
899 float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
900
901 float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
902 float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
903 float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
904 float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
905 float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
906 float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
907 float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
908 float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
909
910 // Compute the 8x4 intermediate tensor
911 float4 comm_fact0, comm_fact1, comm_fact2;
912 float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
913
914 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
915 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
916 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
917 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
918 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
919 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
920 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
921 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
922
923 // Compute the 4x4 output tile
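    // Apply the same combinations along the other dimension: out_colJ = sum_c A^T[J][c] * tmp_col_c,
    // i.e. out_col0..out_col3 are the columns of the 4x4 output tile Y = A^T * M * A.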
924 comm_fact0 = tmp_col1 + tmp_col2;
925 comm_fact1 = tmp_col3 + tmp_col4;
926 comm_fact2 = tmp_col5 + tmp_col6;
927
928 float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
929 float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
930
931 comm_fact0 = tmp_col1 - tmp_col2;
932 comm_fact1 = tmp_col3 - tmp_col4;
933 comm_fact2 = tmp_col5 - tmp_col6;
934
935 float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
936 float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
937
938 int y_in = get_global_id(1);
939 int x_out = (y_in % NUM_TILES_X) * 4;
940 int y_out = (y_in / NUM_TILES_X) * 4;
941 int z_out = get_global_id(0);
942
943#if defined(HAS_BIAS)
944 // Add bias
945 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
946
947 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
948
949 out_col0 += (float4)b;
950 out_col1 += (float4)b;
951 out_col2 += (float4)b;
952 out_col3 += (float4)b;
953#endif // defined(HAS_BIAS)
954
955 // Get output address
956 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
957
958 // Store the 4x4 output tile
959 *(__global float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0;
960 *(__global float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0;
961 *(__global float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0;
962 *(__global float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0;
963 *(__global float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1;
964 *(__global float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1;
965 *(__global float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1;
966 *(__global float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1;
967 *(__global float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2;
968 *(__global float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2;
969 *(__global float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2;
970 *(__global float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2;
971 *(__global float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3;
972 *(__global float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3;
973 *(__global float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3;
974 *(__global float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3;
975}
976
977/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data layout is NHWC
978 *
979 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
980 *
981 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
982 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
983 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
984 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
985 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
986 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
987 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
988 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
989 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
990 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
991 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
992 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
993 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 994 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 995 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
 996 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] dst_size Size of the destination tensor, minus the last padding
997 */
998__kernel void winograd_output_transform_4x4_5x5_nhwc(
999 TENSOR3D_DECLARATION(src),
1000 TENSOR3D_DECLARATION(dst),
1001#if defined(HAS_BIAS)
1002 VECTOR_DECLARATION(bias),
1003#endif // defined(HAS_BIAS)
1004 int dst_size)
1005{
1006 // Each thread stores a 4x4 tile
1007 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
1008
1009 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
1010
1011 // Load the values across the 64 channels to compose the 8x8 input tile
1012 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
1013 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
1014 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
1015 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
1016 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
1017 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
1018 float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
1019 float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
1020
1021 float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
1022 float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
1023 float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
1024 float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
1025 float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
1026 float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
1027 float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
1028 float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
1029
1030 float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
1031 float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
1032 float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
1033 float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
1034 float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
1035 float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
1036 float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
1037 float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
1038
1039 float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
1040 float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
1041 float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
1042 float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
1043 float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
1044 float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
1045 float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
1046 float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
1047
1048 float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
1049 float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
1050 float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
1051 float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
1052 float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
1053 float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
1054 float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
1055 float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
1056
1057 float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
1058 float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
1059 float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
1060 float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
1061 float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
1062 float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
1063 float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
1064 float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
1065
1066 float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
1067 float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
1068 float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
1069 float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
1070 float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
1071 float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
1072 float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
1073 float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
1074
1075 float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
1076 float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
1077 float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
1078 float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
1079 float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
1080 float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
1081 float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
1082 float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
1083
1084 // Compute the 8x4 intermediate tensor
1085 float4 comm_fact0, comm_fact1, comm_fact2;
1086 float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
1087
1088 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1089 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1090 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1091 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1092 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1093 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1094 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1095 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
1096
1097 // Compute the 4x4 output tile
1098 comm_fact0 = tmp_col1 + tmp_col2;
1099 comm_fact1 = tmp_col3 + tmp_col4;
1100 comm_fact2 = tmp_col5 + tmp_col6;
1101
1102 float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
1103 float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
1104
1105 comm_fact0 = tmp_col1 - tmp_col2;
1106 comm_fact1 = tmp_col3 - tmp_col4;
1107 comm_fact2 = tmp_col5 - tmp_col6;
1108
1109 float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
1110 float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
1111
1112 int y_in = get_global_id(1);
1113 int x_out = get_global_id(0);
1114 int y_out = (y_in % NUM_TILES_X) * 4;
1115 int z_out = (y_in / NUM_TILES_X) * 4;
1116
1117#if defined(HAS_BIAS)
1118 // Add bias
1119 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
1120
 1121 float b = (float) * ((__global float *)(vector_offset(&bias, x_out))); // In NHWC the output channel index is x_out
1122
1123 out_col0 += (float4)b;
1124 out_col1 += (float4)b;
1125 out_col2 += (float4)b;
1126 out_col3 += (float4)b;
1127#endif // defined(HAS_BIAS)
1128
1129 // Get output address
1130 int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
1131 offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
1132 int4 mult_y = min(dst_size - offset, 1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
1133
1134 // Store the 4x4 output tile
1135 *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0) = out_col0.s0;
1136 *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0) = out_col1.s0;
1137 *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0) = out_col2.s0;
1138 *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0) = out_col3.s0;
 1139 *(__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1) = out_col0.s1;
 1140 *(__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1) = out_col1.s1;
 1141 *(__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1) = out_col2.s1;
 1142 *(__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1) = out_col3.s1;
 1143 *(__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2) = out_col0.s2;
 1144 *(__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2) = out_col1.s2;
 1145 *(__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2) = out_col2.s2;
 1146 *(__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2) = out_col3.s2;
 1147 *(__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3) = out_col0.s3;
 1148 *(__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3) = out_col1.s3;
 1149 *(__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3) = out_col2.s3;
 1150 *(__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3) = out_col3.s3;
1151}
1152#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)