blob: 4662426a72635f587b4584468528df2d872dd32d [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
27/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
28 *
29 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
30 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
31 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
32 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
33 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
34 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
35 *
36 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
37 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
38 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
39 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
40 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
41 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
42 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
43 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
44 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
45 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
46 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
47 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
48 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
49 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
50 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
51 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
52 */
53__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
54 TENSOR3D_DECLARATION(src),
55 TENSOR3D_DECLARATION(dst))
56{
57 int x = get_global_id(0);
58 int y = get_global_id(1);
59 int z = get_global_id(2);
60
61 // Compute input address
62 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
63
64 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
65
66#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
67 float4 in_row0 = vload4(0, (__global float *)(src_addr));
68#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
69 float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
70 *((__global float *)(src_addr + 1 * src_stride_y)),
71 *((__global float *)(src_addr + 2 * src_stride_y)),
72 *((__global float *)(src_addr + 3 * src_stride_y)));
73#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
74 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
75 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
76 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
77 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
78#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
79
80 float4 tmp0 = in_row0;
81
82#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
83 tmp0 -= in_row2;
84#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
85
86 float out00 = tmp0.s0 - tmp0.s2;
87 float out01 = tmp0.s1 + tmp0.s2;
88 float out02 = tmp0.s2 - tmp0.s1;
89 float out03 = tmp0.s1 - tmp0.s3;
90
91#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
92 float4 tmp1 = in_row1 + in_row2;
93 float4 tmp2 = in_row2 - in_row1;
94 float4 tmp3 = in_row1 - in_row3;
95
96 float out10 = tmp1.s0 - tmp1.s2;
97 float out11 = tmp1.s1 + tmp1.s2;
98 float out12 = tmp1.s2 - tmp1.s1;
99 float out13 = tmp1.s1 - tmp1.s3;
100
101 float out20 = tmp2.s0 - tmp2.s2;
102 float out21 = tmp2.s1 + tmp2.s2;
103 float out22 = tmp2.s2 - tmp2.s1;
104 float out23 = tmp2.s1 - tmp2.s3;
105
106 float out30 = tmp3.s0 - tmp3.s2;
107 float out31 = tmp3.s1 + tmp3.s2;
108 float out32 = tmp3.s2 - tmp3.s1;
109 float out33 = tmp3.s1 - tmp3.s3;
110#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
111
112 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
113
114 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00; // in_row0.s0; out00;
115 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01; // in_row0.s1; out01;
116 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02; // in_row0.s2; out02;
117 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03; // in_row0.s3; out03;
118
119#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
120 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
121 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
122 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
123 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
124 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
125 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
126 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
127 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
128 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
129 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
130 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
131 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
132#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
133}
134
135/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
136 *
137 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
138 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
139 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
140 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
141 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
142 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
143 *
144 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
145 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
146 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
147 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
148 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
149 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
150 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
151 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
152 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
153 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
154 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
155 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
156 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
157 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
158 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
159 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
160 */
161__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
162 TENSOR3D_DECLARATION(src),
163 TENSOR3D_DECLARATION(dst))
164{
165 int x = get_global_id(0);
166 int y = get_global_id(1);
167 int z = get_global_id(2) * 2;
168
169 // Compute input address
170 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
171
172 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
173
174#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
175 float4 in_row0 = vload4(0, (__global float *)(src_addr));
176#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
177 float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
178 *((__global float *)(src_addr + 1 * src_stride_y)),
179 *((__global float *)(src_addr + 2 * src_stride_y)),
180 *((__global float *)(src_addr + 3 * src_stride_y)));
181#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
182 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
183 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
184 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
185 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
186#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
187
188 src_addr += src_stride_z;
189#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
190 float4 in_row4 = vload4(0, (__global float *)(src_addr));
191#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
192 float4 in_row4 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
193 *((__global float *)(src_addr + 1 * src_stride_y)),
194 *((__global float *)(src_addr + 2 * src_stride_y)),
195 *((__global float *)(src_addr + 3 * src_stride_y)));
196#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
197 float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
198 float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
199 float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
200 float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
201#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
202
203 float4 tmp0 = in_row0;
204 float4 tmp4 = in_row4;
205
206#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
207 tmp0 -= in_row2;
208 tmp4 -= in_row6;
209#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
210
211 float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
212 float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
213 float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
214 float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
215
216#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
217 float4 tmp1 = in_row1 + in_row2;
218 float4 tmp2 = in_row2 - in_row1;
219 float4 tmp3 = in_row1 - in_row3;
220
221 float4 tmp5 = in_row5 + in_row6;
222 float4 tmp6 = in_row6 - in_row5;
223 float4 tmp7 = in_row5 - in_row7;
224
225 float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
226 float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
227 float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
228 float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
229
230 float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
231 float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
232 float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
233 float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
234
235 float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
236 float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
237 float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
238 float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
239#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
240
241 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
242
243 vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
244 vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
245 vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
246 vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
247
248#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
249 vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
250 vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
251 vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
252 vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
253 vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
254 vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
255 vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
256 vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
257 vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
258 vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
259 vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
260 vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
261#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
262}
263
264/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW
265 *
266 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
267 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
268 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
269 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
270 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
271 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
272 *
273 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
274 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
275 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
276 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
277 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
278 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
279 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
280 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
281 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
282 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
283 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
284 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
285 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
286 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
287 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
288 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
289 */
290__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
291 TENSOR3D_DECLARATION(src),
292 TENSOR3D_DECLARATION(dst))
293{
294 int x = get_global_id(0);
295 int y = get_global_id(1);
296 int z = get_global_id(2);
297
298 // Compute input address
299 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
300
301 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
302
303#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
304 // Row0
305 float4 d00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
306 *((__global float *)(src_addr + 1 * src_stride_y)),
307 *((__global float *)(src_addr + 2 * src_stride_y)),
308 *((__global float *)(src_addr + 3 * src_stride_y)));
309 float2 d01 = (float2)(*((__global float *)(src_addr + 4 * src_stride_y)),
310 *((__global float *)(src_addr + 5 * src_stride_y)));
311#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
312 // Row0
313 float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
314 float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
315#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
316
317 float out0 = 0.0f;
318 float out1 = 0.0f;
319 float out2 = 0.0f;
320 float out3 = 0.0f;
321 float out4 = 0.0f;
322 float out5 = 0.0f;
323
324 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
325 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
326 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
327 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
328 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
329 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
330 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
331
332#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
333 // Row4
334 float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
335 float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
336
337 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
338 float k0 = d41.s0;
339 float k1 = d41.s0;
340 float k2 = d41.s0;
341 float k3 = d41.s0;
342 float k4 = d41.s0;
343 float k5 = 0.0f;
344
345 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
346 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
347 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
348 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
349 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
350 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
351
352 out0 += k0;
353 out1 += k1;
354 out2 += k2;
355 out3 += k3;
356 out4 += k4;
357 out5 += k5;
358
359 // Row2
360 float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
361 float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
362
363 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
364 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
365 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
366 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
367 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
368 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
369#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
370
371 // Compute destination address
372 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
373
374 uint dst_plane_stride = dst_stride_z / sizeof(float);
375
376 *(dst_addr) = out0;
377 dst_addr += dst_plane_stride;
378 *(dst_addr) = out1;
379 dst_addr += dst_plane_stride;
380 *(dst_addr) = out2;
381 dst_addr += dst_plane_stride;
382 *(dst_addr) = out3;
383 dst_addr += dst_plane_stride;
384 *(dst_addr) = out4;
385 dst_addr += dst_plane_stride;
386 *(dst_addr) = out5;
387 dst_addr += dst_plane_stride;
388
389#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
390 float out6 = k0;
391 float out7 = k1;
392 float out8 = k2;
393 float out9 = k3;
394 float out10 = k4;
395 float out11 = k5;
396 float out12 = k0;
397 float out13 = k1;
398 float out14 = k2;
399 float out15 = k3;
400 float out16 = k4;
401 float out17 = k5;
402 float out18 = k0;
403 float out19 = k1;
404 float out20 = k2;
405 float out21 = k3;
406 float out22 = k4;
407 float out23 = k5;
408 float out24 = k0;
409 float out25 = k1;
410 float out26 = k2;
411 float out27 = k3;
412 float out28 = k4;
413 float out29 = k5;
414
415 // Row1
416 float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
417 float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
418
419 // Row3
420 float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
421 float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
422
423 // Compute common parts for the channels between [6, 29]
424 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
425 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
426 float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
427 float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
428 float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
429 float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
430 float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
431 float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
432 float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
433 float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
434 float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
435 float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
436 float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
437 float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
438
439 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
440 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
441 float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
442 float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
443 float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
444 float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
445 float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
446 float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
447 float part18 = part6 * 0.25f; // d20.s2 - d21.s0
448 float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
449 float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
450 float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
451 float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
452 float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
453
454 out6 += part0 - part1;
455 out12 += part0 + part1;
456 out7 += part2 + part3 + part4 + part5;
457 out8 += part2 - part3 + part4 - part5;
458 out13 += part2 + part3 - part4 - part5;
459 out14 += part2 - part3 - part4 + part5;
460 out9 += part6 + part7 + part8 + part9;
461 out10 += part6 - part7 + part8 - part9;
462 out15 += part6 - part7 - part8 + part9;
463 out16 += part6 + part7 - part8 - part9;
464 out11 += part10 + part11;
465 out17 += part10 - part11;
466
467 out18 += part13 - part12;
468 out24 += part13 + part12;
469 out19 += part14 + part15 + part16 + part17;
470 out20 += part14 - part15 + part16 - part17;
471 out25 += part14 - part15 - part16 + part17;
472 out26 += part14 + part15 - part16 - part17;
473 out21 += part18 + part19 + part20 + part21;
474 out22 += part18 - part19 + part20 - part21;
475 out27 += part18 - part19 - part20 + part21;
476 out28 += part18 + part19 - part20 - part21;
477 out23 += part22 + part23;
478 out29 += part22 - part23;
479
480 *(dst_addr) = out6;
481 dst_addr += dst_plane_stride;
482 *(dst_addr) = out7;
483 dst_addr += dst_plane_stride;
484 *(dst_addr) = out8;
485 dst_addr += dst_plane_stride;
486 *(dst_addr) = out9;
487 dst_addr += dst_plane_stride;
488 *(dst_addr) = out10;
489 dst_addr += dst_plane_stride;
490 *(dst_addr) = out11;
491 dst_addr += dst_plane_stride;
492 *(dst_addr) = out12;
493 dst_addr += dst_plane_stride;
494 *(dst_addr) = out13;
495 dst_addr += dst_plane_stride;
496 *(dst_addr) = out14;
497 dst_addr += dst_plane_stride;
498 *(dst_addr) = out15;
499 dst_addr += dst_plane_stride;
500 *(dst_addr) = out16;
501 dst_addr += dst_plane_stride;
502 *(dst_addr) = out17;
503 dst_addr += dst_plane_stride;
504
505 *(dst_addr) = out18;
506 dst_addr += dst_plane_stride;
507 *(dst_addr) = out19;
508 dst_addr += dst_plane_stride;
509 *(dst_addr) = out20;
510 dst_addr += dst_plane_stride;
511 *(dst_addr) = out21;
512 dst_addr += dst_plane_stride;
513 *(dst_addr) = out22;
514 dst_addr += dst_plane_stride;
515 *(dst_addr) = out23;
516 dst_addr += dst_plane_stride;
517 *(dst_addr) = out24;
518 dst_addr += dst_plane_stride;
519 *(dst_addr) = out25;
520 dst_addr += dst_plane_stride;
521 *(dst_addr) = out26;
522 dst_addr += dst_plane_stride;
523 *(dst_addr) = out27;
524 dst_addr += dst_plane_stride;
525 *(dst_addr) = out28;
526 dst_addr += dst_plane_stride;
527 *(dst_addr) = out29;
528 dst_addr += dst_plane_stride;
529
530 // Row5
531 float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
532 float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
533
534 // Channels [30, 35]
535 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
536 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
537 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
538 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
539 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
540 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
541
542 *(dst_addr) = out0;
543 dst_addr += dst_plane_stride;
544 *(dst_addr) = out1;
545 dst_addr += dst_plane_stride;
546 *(dst_addr) = out2;
547 dst_addr += dst_plane_stride;
548 *(dst_addr) = out3;
549 dst_addr += dst_plane_stride;
550 *(dst_addr) = out4;
551 dst_addr += dst_plane_stride;
552 *(dst_addr) = out5;
553 dst_addr += dst_plane_stride;
554#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
555}
556
557#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
558/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC
559 *
560 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
561 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
564 *
565 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
566 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
567 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
568 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
569 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
570 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
571 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
572 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
573 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
574 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
575 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
576 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
577 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
578 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
579 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
580 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
581 */
582__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
583 TENSOR3D_DECLARATION(src),
584 TENSOR3D_DECLARATION(dst))
585{
586 int x = get_global_id(0);
587 int y = get_global_id(1);
588 int z = get_global_id(2);
589
590 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * src_stride_x;
591
592 // Clamp coordinates. This clamp is valid for all rows
593 int4 y_coord0 = (int4)(y * 4) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
594 int2 y_coord1 = (int2)(y * 4) + (int2)(4, 5) - (int2)PAD_LEFT;
595 y_coord0 = clamp(y_coord0, -1, SRC_DIM_1);
596 y_coord1 = clamp(y_coord1, -1, SRC_DIM_1);
597
598 // Row4
599 int z_coord = (z * 4) - PAD_TOP + 4;
600
601 // If z < 0, set y to -1
602 int4 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
603 int2 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
604 // If z >= SRC_DIM_2, set y to SRC_DIM_2
605 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
606 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
607
608 // Clamp z coordinate
609 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
610
611 float d40 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
612 float d41 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
613 float d42 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
614 float d43 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
615 float d44 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
616 float d45 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
617
618 float k0 = d44;
619 float k1 = d44;
620 float k2 = d44;
621 float k3 = d44;
622 float k4 = d44;
623 float k5 = (float)0.0f;
624
625 k0 += 4.0f * d40 - 5.0f * d42;
626 k1 += -4.0f * d41 - 4.0f * d42 + d43;
627 k2 += 4.0f * d41 - 4.0f * d42 - d43;
628 k3 += -2.0f * d41 + 2.0f * d43 - d42;
629 k4 += 2.0f * d41 - 2.0f * d43 - d42;
630 k5 += 4.0f * d41 - 5.0f * d43 + d45;
631
632 // Row0
633 z_coord = (z * 4) - PAD_TOP + 0;
634
635#if PAD_TOP != 0
636 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
637 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
638 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
639 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
640 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
641#else // PAD_TOP != 0
642 valid_y0 = y_coord0;
643 valid_y1 = y_coord1;
644#endif // if PAD_TOP == 0, we cannot read out of bound
645
646 float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
647 float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
648 float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
649 float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
650 float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
651 float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
652
653 // Row2
654 z_coord = (z * 4) - PAD_TOP + 2;
655 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
656 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
657 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
658 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
659 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
660
661 float d20 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
662 float d21 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
663 float d22 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
664 float d23 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
665 float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
666 float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
667
668 // Compute destination address
669 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_x + (y + z * (int)NUM_TILES_X) * dst_stride_y);
670
671 uint dst_plane_stride = dst_stride_z / sizeof(float);
672
673 float out0 = k0;
674 float out1 = k1;
675 float out2 = k2;
676 float out3 = k3;
677 float out4 = k4;
678 float out5 = k5;
679 float out6 = k0;
680 float out7 = k1;
681 float out8 = k2;
682 float out9 = k3;
683 float out10 = k4;
684 float out11 = k5;
685 float out12 = k0;
686 float out13 = k1;
687 float out14 = k2;
688 float out15 = k3;
689 float out16 = k4;
690 float out17 = k5;
691 float out18 = k0;
692 float out19 = k1;
693 float out20 = k2;
694 float out21 = k3;
695 float out22 = k4;
696 float out23 = k5;
697 float out24 = k0;
698 float out25 = k1;
699 float out26 = k2;
700 float out27 = k3;
701 float out28 = k4;
702 float out29 = k5;
703
704 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
705 out0 += 16.0f * d00 - 20.0f * d02 - 20.0f * d20 + 25.0f * d22 + 4.0f * d04 - 5.0f * d24;
706 out1 += -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 20.0f * d21 + 20.0f * d22 - 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
707 out2 += 16.0f * d01 - 16.0f * d02 - 4.0f * d03 - 20.0f * d21 + 20.0f * d22 + 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
708 out3 += -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 10.0f * d21 + 5.0f * d22 - 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
709 out4 += 8.0f * d01 - 4.0f * d02 - 8.0f * d03 - 10.0f * d21 + 5.0f * d22 + 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
710 out5 += 16.0f * d01 - 20.0f * d03 - 20.0f * d21 + 4.0f * d05 + 25.0f * d23 - 5.0f * d25;
711
712 *((__global float *)dst_addr) = out0;
713 dst_addr += dst_plane_stride;
714 *((__global float *)dst_addr) = out1;
715 dst_addr += dst_plane_stride;
716 *((__global float *)dst_addr) = out2;
717 dst_addr += dst_plane_stride;
718 *((__global float *)dst_addr) = out3;
719 dst_addr += dst_plane_stride;
720 *((__global float *)dst_addr) = out4;
721 dst_addr += dst_plane_stride;
722 *((__global float *)dst_addr) = out5;
723 dst_addr += dst_plane_stride;
724
725 // Row1
726 z_coord = (z * 4) - PAD_TOP + 1;
727 // Row1 can never be out of bounds
728 valid_y0 = y_coord0;
729 valid_y1 = y_coord1;
730
731 float d10 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
732 float d11 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
733 float d12 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
734 float d13 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
735 float d14 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
736 float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
737
738 // Row3
739 z_coord = (z * 4) - PAD_TOP + 3;
740 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
741 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
742 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
743 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
744 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
745 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
746
747 float d30 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
748 float d31 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
749 float d32 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
750 float d33 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
751 float d34 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
752 float d35 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
753
754 // Compute common parts for the channels between [6, 29]
755 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
756 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
757 float part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
758 float part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
759 float part2 = 16.0f * d22 - 4.0f * d24;
760 float part3 = 16.0f * d21 - 4.0f * d23;
761 float part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
762 float part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
763 float part6 = 4.0f * d22 - 4.0f * d24;
764 float part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
765 float part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
766 float part9 = 8.0f * d21 - 8.0f * d23;
767 float part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
768 float part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
769
770 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
771 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
772 float part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
773 float part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
774 float part14 = part2 * 0.25f; // 4.0f * d22 - d24
775 float part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
776 float part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
777 float part17 = part3 * 0.25f; // 4.0f * d21 - d23
778 float part18 = part6 * 0.25f; // d22 - d24
779 float part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
780 float part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
781 float part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
782 float part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
783 float part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
784
785 out6 += part0 - part1;
786 out12 += part0 + part1;
787 out7 += part2 + part3 + part4 + part5;
788 out8 += part2 - part3 + part4 - part5;
789 out13 += part2 + part3 - part4 - part5;
790 out14 += part2 - part3 - part4 + part5;
791 out9 += part6 + part7 + part8 + part9;
792 out10 += part6 - part7 + part8 - part9;
793 out15 += part6 - part7 - part8 + part9;
794 out16 += part6 + part7 - part8 - part9;
795 out11 += part10 + part11;
796 out17 += part10 - part11;
797
798 out18 += part13 - part12;
799 out24 += part13 + part12;
800 out19 += part14 + part15 + part16 + part17;
801 out20 += part14 - part15 + part16 - part17;
802 out25 += part14 - part15 - part16 + part17;
803 out26 += part14 + part15 - part16 - part17;
804 out21 += part18 + part19 + part20 + part21;
805 out22 += part18 - part19 + part20 - part21;
806 out27 += part18 - part19 - part20 + part21;
807 out28 += part18 + part19 - part20 - part21;
808 out23 += part22 + part23;
809 out29 += part22 - part23;
810
811 *((__global float *)dst_addr) = out6;
812 dst_addr += dst_plane_stride;
813 *((__global float *)dst_addr) = out7;
814 dst_addr += dst_plane_stride;
815 *((__global float *)dst_addr) = out8;
816 dst_addr += dst_plane_stride;
817 *((__global float *)dst_addr) = out9;
818 dst_addr += dst_plane_stride;
819 *((__global float *)dst_addr) = out10;
820 dst_addr += dst_plane_stride;
821 *((__global float *)dst_addr) = out11;
822 dst_addr += dst_plane_stride;
823 *((__global float *)dst_addr) = out12;
824 dst_addr += dst_plane_stride;
825 *((__global float *)dst_addr) = out13;
826 dst_addr += dst_plane_stride;
827 *((__global float *)dst_addr) = out14;
828 dst_addr += dst_plane_stride;
829 *((__global float *)dst_addr) = out15;
830 dst_addr += dst_plane_stride;
831 *((__global float *)dst_addr) = out16;
832 dst_addr += dst_plane_stride;
833 *((__global float *)dst_addr) = out17;
834 dst_addr += dst_plane_stride;
835
836 *((__global float *)dst_addr) = out18;
837 dst_addr += dst_plane_stride;
838 *((__global float *)dst_addr) = out19;
839 dst_addr += dst_plane_stride;
840 *((__global float *)dst_addr) = out20;
841 dst_addr += dst_plane_stride;
842 *((__global float *)dst_addr) = out21;
843 dst_addr += dst_plane_stride;
844 *((__global float *)dst_addr) = out22;
845 dst_addr += dst_plane_stride;
846 *((__global float *)dst_addr) = out23;
847 dst_addr += dst_plane_stride;
848 *((__global float *)dst_addr) = out24;
849 dst_addr += dst_plane_stride;
850 *((__global float *)dst_addr) = out25;
851 dst_addr += dst_plane_stride;
852 *((__global float *)dst_addr) = out26;
853 dst_addr += dst_plane_stride;
854 *((__global float *)dst_addr) = out27;
855 dst_addr += dst_plane_stride;
856 *((__global float *)dst_addr) = out28;
857 dst_addr += dst_plane_stride;
858 *((__global float *)dst_addr) = out29;
859 dst_addr += dst_plane_stride;
860
861 // Row5
862 z_coord = (z * 4) - PAD_TOP + 5;
863 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
864 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
865 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
866 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
867 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
868 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
869
870 float d50 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
871 float d51 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
872 float d52 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
873 float d53 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
874 float d54 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
875 float d55 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
876
877 // Channels [30, 35]
878 out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
879 out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
880 out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
881 out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
882 out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
883 out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
884
885 *((__global float *)dst_addr) = out0;
886 dst_addr += dst_plane_stride;
887 *((__global float *)dst_addr) = out1;
888 dst_addr += dst_plane_stride;
889 *((__global float *)dst_addr) = out2;
890 dst_addr += dst_plane_stride;
891 *((__global float *)dst_addr) = out3;
892 dst_addr += dst_plane_stride;
893 *((__global float *)dst_addr) = out4;
894 dst_addr += dst_plane_stride;
895 *((__global float *)dst_addr) = out5;
896 dst_addr += dst_plane_stride;
897}
898
899#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
900
/** Computes one 8-element output row of the 4x4/5x5 Winograd input transform.
 *
 * Reads components s0..s7 of @p tmp and writes components s0..s7 of @p out, using
 * components s0..s6 of @p comm_fact as scratch storage (their previous contents are overwritten).
 *
 * @note Every argument is expanded several times, so pass plain variables without side effects.
 * @note Implemented as a statement expression (compiler extension), exactly like the original.
 *
 * @param[out]     out       Vector (or any object exposing s0..s7) receiving the transformed row
 * @param[in]      tmp       Vector (or any object exposing s0..s7) holding the row to transform
 * @param[in, out] comm_fact Scratch vector used for the shared sub-expressions
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                         \
    ({                                                                                  \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                        \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                        \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                               \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;             \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;                \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                    \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;             \
                                                                                        \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;           \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                     \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                     \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                     \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                     \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                     \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                     \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;           \
    })
920
921/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4 when the data layout is NCHW
922 *
923 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
924 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
925 *
926 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
927 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
928 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
929 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
930 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
931 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
932 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
933 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
934 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
935 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
936 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
937 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
938 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
939 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
940 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
941 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
942 */
943__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
944 TENSOR3D_DECLARATION(src),
945 TENSOR3D_DECLARATION(dst))
946{
947 int x = get_global_id(0);
948 int y = get_global_id(1);
949 int z = get_global_id(2);
950
951 // Compute input address
952 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
953
954 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
955
956 // Load 8x8 input tile
957 const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
958 const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
959 const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
960 const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
961 const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
962 const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
963 const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
964 const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
965
966 // Calculate common factors for intermediate tensor
967 float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
968 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
969 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
970
971 // Calculate intermediate tensor and reuse common factor vectors
972 const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
973 const float8 tmp1 = comm_fact0 + comm_fact1;
974 const float8 tmp2 = comm_fact0 - comm_fact1;
975
976 comm_fact0 = 2.5f * in_row3;
977 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
978
979 const float8 tmp3 = comm_fact1 + comm_fact2;
980 const float8 tmp4 = comm_fact2 - comm_fact1;
981
982 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
983 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
984
985 const float8 tmp5 = comm_fact1 + comm_fact2;
986 const float8 tmp6 = comm_fact2 - comm_fact1;
987 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
988
989 // Calculate output rows (reuse comm_fact0 vector)
990 float8 out0, out1, out2, out3, out4, out5, out6, out7;
991
992 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
993 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
994 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
995 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
996 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
997 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
998 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
999 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
1000
1001 // Store values across the 64 channels
1002 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
1003
1004 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1005 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1006 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1007 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1008 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1009 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1010 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1011 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1012 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1013 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1014 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1015 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1016 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1017 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1018 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1019 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1020 *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1021 *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1022 *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1023 *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1024 *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1025 *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1026 *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1027 *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1028 *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1029 *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1030 *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1031 *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1032 *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1033 *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1034 *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1035 *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1036 *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1037 *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1038 *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1039 *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1040 *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1041 *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1042 *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1043 *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1044 *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1045 *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1046 *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1047 *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1048 *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1049 *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1050 *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1051 *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1052 *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1053 *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1054 *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1055 *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1056 *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1057 *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1058 *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1059 *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1060 *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1061 *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1062 *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1063 *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1064 *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1065 *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1066 *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1067 *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
1068}
1069
1070#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1071/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
1072 *
1073 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1074 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1075 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1076 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1077 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1078 *
1079 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1080 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1081 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1082 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1083 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1084 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1085 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1086 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1087 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1088 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1089 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1090 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1091 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1092 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1093 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1094 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1095 */
1096__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
1097 TENSOR3D_DECLARATION(src),
1098 TENSOR3D_DECLARATION(dst))
1099{
1100 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
1101 src_stride_x,
1102 src_step_x,
1103 src_stride_y,
1104 src_step_y,
1105 src_stride_z,
1106 src_step_z,
1107 src_offset_first_element_in_bytes,
1108 dst_ptr,
1109 dst_stride_x,
1110 dst_step_x,
1111 dst_stride_y,
1112 dst_step_y,
1113 dst_stride_z,
1114 dst_step_z,
1115 dst_offset_first_element_in_bytes);
1116}
1117
1118/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
1119 *
1120 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1121 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1122 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1123 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1124 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1125 *
1126 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1127 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1128 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1129 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1130 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1131 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1132 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1133 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1134 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1135 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1136 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1137 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1138 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1139 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1140 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1141 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1142 */
1143__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
1144 TENSOR3D_DECLARATION(src),
1145 TENSOR3D_DECLARATION(dst))
1146{
1147 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
1148 src_stride_x,
1149 src_step_x,
1150 src_stride_y,
1151 src_step_y,
1152 src_stride_z,
1153 src_step_z,
1154 src_offset_first_element_in_bytes,
1155 dst_ptr,
1156 dst_stride_x,
1157 dst_step_x,
1158 dst_stride_y,
1159 dst_step_y,
1160 dst_stride_z,
1161 dst_step_z,
1162 dst_offset_first_element_in_bytes);
1163}
1164
/** Winograd input transform entry point for a 3x1 kernel with a 4x1 output tile (NCHW variant)
 *
 * The kernel body only forwards its tensor arguments, unchanged, to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1).
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Pure pass-through: source arguments first, destination arguments second, in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1211#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1212
1213#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
/** Winograd input transform entry point for a 1x3 kernel with a 1x2 output tile (NCHW variant)
 *
 * The kernel body only forwards its tensor arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Pure pass-through: source arguments first, destination arguments second, in declaration order.
    winograd_input_transform_2x2_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1260
/** Winograd input transform entry point for a 1x3 kernel with a 1x2 output tile, for a number of channels that is a multiple of 2 (NCHW variant)
 *
 * The kernel body only forwards its tensor arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Pure pass-through: source arguments first, destination arguments second, in declaration order.
    winograd_input_transform_2x2_3x3_stepz2_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1307
/** Winograd input transform entry point for a 1x3 kernel with a 1x4 output tile (NCHW variant)
 *
 * The kernel body only forwards its tensor arguments, unchanged, to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Pure pass-through: source arguments first, destination arguments second, in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1354#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1355
1356#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** Load one row (eight floats) of the 8x8 input tile read by winograd_input_transform_4x4_5x5_stepz1_nhwc
 *
 * Lane i is read from (base_addr + y'[i] * stride_y + z' * stride_z), where z' is @p z_idx clamped to [0, SRC_DIM_2 - 1]
 * and y' equals @p y_idx, except that all lanes are redirected to -1 when @p z_idx < 0 and to SRC_DIM_1 when @p z_idx >= SRC_DIM_2.
 *
 * @note The redirected coordinates (-1 and SRC_DIM_1) lie one element outside [0, SRC_DIM_1 - 1]; presumably the source tensor
 *       provides a zero-filled border there - TODO confirm against the host-side kernel configuration.
 * @note The constant operands of select() are written as explicit int8 vectors so that all three arguments match the
 *       select(gentype, gentype, igentype) signature without relying on implicit scalar-to-vector conversion.
 *
 * @param[out] row       float8 variable that receives the eight loaded values
 * @param[in]  base_addr Byte address (__global uchar *) the Y and Z offsets are applied to
 * @param[in]  y_idx     int8 with the eight (already clamped) coordinates along Y
 * @param[in]  z_idx     int with the unclamped coordinate along Z of this row
 * @param[in]  stride_y  Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  stride_z  Stride of the source tensor in Z dimension (in bytes)
 */
#define WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(row, base_addr, y_idx, z_idx, stride_y, stride_z)            \
    {                                                                                               \
        const int z_row_ = (z_idx);                                                                 \
        int8      y_row_ = select((y_idx), (int8)(-1), (int8)(z_row_) < 0);                          \
        y_row_           = select(y_row_, (int8)(SRC_DIM_1), (int8)(z_row_) >= SRC_DIM_2);           \
        __global uchar *row_addr_ = (base_addr) + clamp(z_row_, 0, SRC_DIM_2 - 1) * (stride_z);     \
        (row).s0 = *(__global float *)(row_addr_ + y_row_.s0 * (int)(stride_y));                    \
        (row).s1 = *(__global float *)(row_addr_ + y_row_.s1 * (int)(stride_y));                    \
        (row).s2 = *(__global float *)(row_addr_ + y_row_.s2 * (int)(stride_y));                    \
        (row).s3 = *(__global float *)(row_addr_ + y_row_.s3 * (int)(stride_y));                    \
        (row).s4 = *(__global float *)(row_addr_ + y_row_.s4 * (int)(stride_y));                    \
        (row).s5 = *(__global float *)(row_addr_ + y_row_.s5 * (int)(stride_y));                    \
        (row).s6 = *(__global float *)(row_addr_ + y_row_.s6 * (int)(stride_y));                    \
        (row).s7 = *(__global float *)(row_addr_ + y_row_.s7 * (int)(stride_y));                    \
    }

/** Store the eight lanes of one transformed row, one float per Z plane of the destination
 *
 * Lane i of @p row is written at (base_addr + (8 * row_id + i) * stride_z).
 *
 * @param[in] row       float8 holding the values to store
 * @param[in] row_id    Index of the row inside the 8x8 transformed tile (0..7)
 * @param[in] base_addr Byte address (__global uchar *) of the element in the first Z plane
 * @param[in] stride_z  Stride of the destination tensor in Z dimension (in bytes)
 */
#define WINOGRAD_STORE_ROW_4x4_5x5_NHWC(row, row_id, base_addr, stride_z)                           \
    {                                                                                               \
        *((__global float *)((base_addr) + ((row_id) * 8 + 0) * (stride_z))) = (row).s0;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 1) * (stride_z))) = (row).s1;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 2) * (stride_z))) = (row).s2;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 3) * (stride_z))) = (row).s3;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 4) * (stride_z))) = (row).s4;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 5) * (stride_z))) = (row).s5;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 6) * (stride_z))) = (row).s6;            \
        *((__global float *)((base_addr) + ((row_id) * 8 + 7) * (stride_z))) = (row).s7;            \
    }

/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4 when the data layout is NHWC
 *
 * Each work-item reads one 8x8 tile of floats (eight positions along Y by eight positions along Z of the source, at the
 * X element selected by get_global_id(0)), transforms it and writes the 64 results to 64 consecutive Z planes of the
 * destination, at X = get_global_id(0) and Y = get_global_id(1) + get_global_id(2) * NUM_TILES_X.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note The bounds of the source coordinates along Y and Z must be passed at compile time using -DSRC_DIM_1 and -DSRC_DIM_2
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes). Not read by this kernel
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes). Not read by this kernel
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes). Not read by this kernel
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes). Not read by this kernel
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes). Not read by this kernel
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes). Not read by this kernel
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes). Not read by this kernel
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes). Not read by this kernel
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
    const int z = get_global_id(2);

    // Compute input address: only the X offset is applied here, the Y/Z offsets are added per loaded element
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);

    // Coordinates along Y of the eight tile columns (tiles advance by 4). This clamp is valid for all rows:
    // anything left of the tensor collapses to -1, anything right of it to SRC_DIM_1
    int8 y_coord = (int8)(y * 4) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)(PAD_LEFT);
    y_coord      = clamp(y_coord, -1, SRC_DIM_1);

    // Load 8x8 input tile, one row per coordinate along Z
    float8 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;

    // Unclamped coordinate along Z of the first tile row
    const int z_start = (z * 4) - PAD_TOP;

    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row0, src_addr, y_coord, z_start + 0, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row1, src_addr, y_coord, z_start + 1, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row2, src_addr, y_coord, z_start + 2, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row3, src_addr, y_coord, z_start + 3, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row4, src_addr, y_coord, z_start + 4, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row5, src_addr, y_coord, z_start + 5, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row6, src_addr, y_coord, z_start + 6, src_stride_y, src_stride_z);
    WINOGRAD_LOAD_ROW_4x4_5x5_NHWC(in_row7, src_addr, y_coord, z_start + 7, src_stride_y, src_stride_z);

    // Calculate common factors for intermediate tensor
    float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
    float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
    float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;

    // Calculate intermediate tensor and reuse common factor vectors
    const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
    const float8 tmp1 = comm_fact0 + comm_fact1;
    const float8 tmp2 = comm_fact0 - comm_fact1;

    comm_fact0 = 2.5f * in_row3;
    comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;

    const float8 tmp3 = comm_fact1 + comm_fact2;
    const float8 tmp4 = comm_fact2 - comm_fact1;

    comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
    comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;

    const float8 tmp5 = comm_fact1 + comm_fact2;
    const float8 tmp6 = comm_fact2 - comm_fact1;
    const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;

    // Calculate output rows (comm_fact0 is handed to the macro as a reusable vector)
    float8 out0, out1, out2, out3, out4, out5, out6, out7;

    OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);

    // Store values across the 64 channels (Z planes) of the destination
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y;

    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out0, 0, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out1, 1, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out2, 2, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out3, 3, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out4, 4, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out5, 5, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out6, 6, dst_addr, dst_stride_z);
    WINOGRAD_STORE_ROW_4x4_5x5_NHWC(out7, 7, dst_addr, dst_stride_z);
}

#undef WINOGRAD_STORE_ROW_4x4_5x5_NHWC
#undef WINOGRAD_LOAD_ROW_4x4_5x5_NHWC
1622#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
1623#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)