blob: da18e4ab5bee939e4c27786d6ed01867d19292a8 [file] [log] [blame]
/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
25
/** Transforms one 8-element row: reads @p tmp (.s0-.s7), writes @p out (.s0-.s7).
 *
 * The shared sub-expressions are staged in @p comm_fact (.s0-.s6) so that each pair of
 * outputs (s1/s2, s3/s4, s5/s6) is produced as a sum and a difference of two of them.
 *
 * @note Written as a GNU-style statement expression; it is meant to be invoked as a statement.
 * @note Every argument is expanded several times: pass expressions without side effects.
 *       Arguments are parenthesized in the expansion, so lvalue expressions such as
 *       a dereferenced pointer are accepted.
 * @note NOTE(review): the name suggests the 4x4 output tile / 5x5 filter Winograd input
 *       transform; the call sites are outside this excerpt - confirm there.
 *
 * @param[out]    out       Lvalue with components .s0-.s7 that receives the transformed row
 * @param[in]     tmp       Expression with components .s0-.s7 holding the row to transform
 * @param[in,out] comm_fact Scratch lvalue; its components .s0-.s6 are overwritten
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
    ({ \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5; \
        (comm_fact).s2 = 2.5f * (tmp).s3; \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2; \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4; \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2; \
        \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2; \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1; \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1; \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4; \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3; \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6; \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6; \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5; \
    })
45
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010046#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
47/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
48 *
49 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
50 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
51 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
52 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
53 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
54 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
55 *
56 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
57 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
58 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
59 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
60 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
61 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
62 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
63 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
64 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
65 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
66 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
67 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
68 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
69 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
70 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
71 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
72 */
73__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
74 TENSOR3D_DECLARATION(src),
75 TENSOR3D_DECLARATION(dst))
76{
77 int x = get_global_id(0);
78 int y = get_global_id(1);
79 int z = get_global_id(2);
80
81 // Compute input address
82 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
83
84 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
85
86#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
87 float4 in_row0 = vload4(0, (__global float *)(src_addr));
88#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
89 float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
90 *((__global float *)(src_addr + 1 * src_stride_y)),
91 *((__global float *)(src_addr + 2 * src_stride_y)),
92 *((__global float *)(src_addr + 3 * src_stride_y)));
93#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +010094 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
95 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
96 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
97 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010098#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
99
100 float4 tmp0 = in_row0;
101
102#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
103 tmp0 -= in_row2;
104#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
105
106 float out00 = tmp0.s0 - tmp0.s2;
107 float out01 = tmp0.s1 + tmp0.s2;
108 float out02 = tmp0.s2 - tmp0.s1;
109 float out03 = tmp0.s1 - tmp0.s3;
110
111#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
112 float4 tmp1 = in_row1 + in_row2;
113 float4 tmp2 = in_row2 - in_row1;
114 float4 tmp3 = in_row1 - in_row3;
115
116 float out10 = tmp1.s0 - tmp1.s2;
117 float out11 = tmp1.s1 + tmp1.s2;
118 float out12 = tmp1.s2 - tmp1.s1;
119 float out13 = tmp1.s1 - tmp1.s3;
120
121 float out20 = tmp2.s0 - tmp2.s2;
122 float out21 = tmp2.s1 + tmp2.s2;
123 float out22 = tmp2.s2 - tmp2.s1;
124 float out23 = tmp2.s1 - tmp2.s3;
125
126 float out30 = tmp3.s0 - tmp3.s2;
127 float out31 = tmp3.s1 + tmp3.s2;
128 float out32 = tmp3.s2 - tmp3.s1;
129 float out33 = tmp3.s1 - tmp3.s3;
130#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
131
132 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
133
134 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00; // in_row0.s0; out00;
135 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01; // in_row0.s1; out01;
136 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02; // in_row0.s2; out02;
137 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03; // in_row0.s3; out03;
138
139#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
140 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
141 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
142 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
143 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
144 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
145 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
146 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
147 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
148 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
149 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
150 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
151 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
152#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
153}
154
155/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
156 *
157 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
158 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
159 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
160 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
161 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
162 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
163 *
164 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
165 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
166 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
167 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
168 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
169 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
170 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
171 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
172 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
173 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
174 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
175 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
176 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
177 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
178 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
179 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
180 */
181__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
182 TENSOR3D_DECLARATION(src),
183 TENSOR3D_DECLARATION(dst))
184{
185 int x = get_global_id(0);
186 int y = get_global_id(1);
187 int z = get_global_id(2) * 2;
188
189 // Compute input address
190 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
191
192 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
193
194#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
195 float4 in_row0 = vload4(0, (__global float *)(src_addr));
196#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
197 float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
198 *((__global float *)(src_addr + 1 * src_stride_y)),
199 *((__global float *)(src_addr + 2 * src_stride_y)),
200 *((__global float *)(src_addr + 3 * src_stride_y)));
201#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100202 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
203 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
204 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
205 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100206#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
207
208 src_addr += src_stride_z;
209#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
210 float4 in_row4 = vload4(0, (__global float *)(src_addr));
211#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
212 float4 in_row4 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
213 *((__global float *)(src_addr + 1 * src_stride_y)),
214 *((__global float *)(src_addr + 2 * src_stride_y)),
215 *((__global float *)(src_addr + 3 * src_stride_y)));
216#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100217 float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
218 float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
219 float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
220 float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100221#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
222
223 float4 tmp0 = in_row0;
224 float4 tmp4 = in_row4;
225
226#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
227 tmp0 -= in_row2;
228 tmp4 -= in_row6;
229#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
230
231 float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
232 float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
233 float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
234 float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
235
236#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
237 float4 tmp1 = in_row1 + in_row2;
238 float4 tmp2 = in_row2 - in_row1;
239 float4 tmp3 = in_row1 - in_row3;
240
241 float4 tmp5 = in_row5 + in_row6;
242 float4 tmp6 = in_row6 - in_row5;
243 float4 tmp7 = in_row5 - in_row7;
244
245 float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
246 float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
247 float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
248 float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
249
250 float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
251 float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
252 float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
253 float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
254
255 float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
256 float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
257 float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
258 float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
259#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
260
261 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
262
263 vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
264 vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
265 vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
266 vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
267
268#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
269 vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
270 vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
271 vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
272 vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
273 vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
274 vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
275 vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
276 vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
277 vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
278 vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
279 vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
280 vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
281#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
282}
283
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100284/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100285 *
286 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
287 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
288 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
289 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
290 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
291 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
292 *
293 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
294 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
295 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
296 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
297 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
298 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
299 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
300 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
301 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
302 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
303 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
304 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
305 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
306 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
307 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
308 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
309 */
310__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
311 TENSOR3D_DECLARATION(src),
312 TENSOR3D_DECLARATION(dst))
313{
314 int x = get_global_id(0);
315 int y = get_global_id(1);
316 int z = get_global_id(2);
317
318 // Compute input address
319 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
320
321 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
322
323#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
324 // Row0
325 float4 d00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
326 *((__global float *)(src_addr + 1 * src_stride_y)),
327 *((__global float *)(src_addr + 2 * src_stride_y)),
328 *((__global float *)(src_addr + 3 * src_stride_y)));
329 float2 d01 = (float2)(*((__global float *)(src_addr + 4 * src_stride_y)),
330 *((__global float *)(src_addr + 5 * src_stride_y)));
331#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
332 // Row0
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100333 float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
334 float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100335#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
336
337 float out0 = 0.0f;
338 float out1 = 0.0f;
339 float out2 = 0.0f;
340 float out3 = 0.0f;
341 float out4 = 0.0f;
342 float out5 = 0.0f;
343
344 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
345 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
346 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
347 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
348 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
349 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
350 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
351
352#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
353 // Row4
354 float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
355 float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
356
357 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
358 float k0 = d41.s0;
359 float k1 = d41.s0;
360 float k2 = d41.s0;
361 float k3 = d41.s0;
362 float k4 = d41.s0;
363 float k5 = 0.0f;
364
365 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
366 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
367 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
368 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
369 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
370 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
371
372 out0 += k0;
373 out1 += k1;
374 out2 += k2;
375 out3 += k3;
376 out4 += k4;
377 out5 += k5;
378
379 // Row2
380 float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
381 float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
382
383 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
384 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
385 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
386 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
387 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
388 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
389#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
390
391 // Compute destination address
392 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
393
394 uint dst_plane_stride = dst_stride_z / sizeof(float);
395
396 *(dst_addr) = out0;
397 dst_addr += dst_plane_stride;
398 *(dst_addr) = out1;
399 dst_addr += dst_plane_stride;
400 *(dst_addr) = out2;
401 dst_addr += dst_plane_stride;
402 *(dst_addr) = out3;
403 dst_addr += dst_plane_stride;
404 *(dst_addr) = out4;
405 dst_addr += dst_plane_stride;
406 *(dst_addr) = out5;
407 dst_addr += dst_plane_stride;
408
409#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
410 float out6 = k0;
411 float out7 = k1;
412 float out8 = k2;
413 float out9 = k3;
414 float out10 = k4;
415 float out11 = k5;
416 float out12 = k0;
417 float out13 = k1;
418 float out14 = k2;
419 float out15 = k3;
420 float out16 = k4;
421 float out17 = k5;
422 float out18 = k0;
423 float out19 = k1;
424 float out20 = k2;
425 float out21 = k3;
426 float out22 = k4;
427 float out23 = k5;
428 float out24 = k0;
429 float out25 = k1;
430 float out26 = k2;
431 float out27 = k3;
432 float out28 = k4;
433 float out29 = k5;
434
435 // Row1
436 float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
437 float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
438
439 // Row3
440 float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
441 float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
442
443 // Compute common parts for the channels between [6, 29]
444 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
445 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
446 float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
447 float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
448 float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
449 float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
450 float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
451 float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
452 float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
453 float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
454 float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
455 float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
456 float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
457 float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
458
459 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
460 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
461 float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
462 float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
463 float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
464 float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
465 float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
466 float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
467 float part18 = part6 * 0.25f; // d20.s2 - d21.s0
468 float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
469 float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
470 float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
471 float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
472 float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
473
474 out6 += part0 - part1;
475 out12 += part0 + part1;
476 out7 += part2 + part3 + part4 + part5;
477 out8 += part2 - part3 + part4 - part5;
478 out13 += part2 + part3 - part4 - part5;
479 out14 += part2 - part3 - part4 + part5;
480 out9 += part6 + part7 + part8 + part9;
481 out10 += part6 - part7 + part8 - part9;
482 out15 += part6 - part7 - part8 + part9;
483 out16 += part6 + part7 - part8 - part9;
484 out11 += part10 + part11;
485 out17 += part10 - part11;
486
487 out18 += part13 - part12;
488 out24 += part13 + part12;
489 out19 += part14 + part15 + part16 + part17;
490 out20 += part14 - part15 + part16 - part17;
491 out25 += part14 - part15 - part16 + part17;
492 out26 += part14 + part15 - part16 - part17;
493 out21 += part18 + part19 + part20 + part21;
494 out22 += part18 - part19 + part20 - part21;
495 out27 += part18 - part19 - part20 + part21;
496 out28 += part18 + part19 - part20 - part21;
497 out23 += part22 + part23;
498 out29 += part22 - part23;
499
500 *(dst_addr) = out6;
501 dst_addr += dst_plane_stride;
502 *(dst_addr) = out7;
503 dst_addr += dst_plane_stride;
504 *(dst_addr) = out8;
505 dst_addr += dst_plane_stride;
506 *(dst_addr) = out9;
507 dst_addr += dst_plane_stride;
508 *(dst_addr) = out10;
509 dst_addr += dst_plane_stride;
510 *(dst_addr) = out11;
511 dst_addr += dst_plane_stride;
512 *(dst_addr) = out12;
513 dst_addr += dst_plane_stride;
514 *(dst_addr) = out13;
515 dst_addr += dst_plane_stride;
516 *(dst_addr) = out14;
517 dst_addr += dst_plane_stride;
518 *(dst_addr) = out15;
519 dst_addr += dst_plane_stride;
520 *(dst_addr) = out16;
521 dst_addr += dst_plane_stride;
522 *(dst_addr) = out17;
523 dst_addr += dst_plane_stride;
524
525 *(dst_addr) = out18;
526 dst_addr += dst_plane_stride;
527 *(dst_addr) = out19;
528 dst_addr += dst_plane_stride;
529 *(dst_addr) = out20;
530 dst_addr += dst_plane_stride;
531 *(dst_addr) = out21;
532 dst_addr += dst_plane_stride;
533 *(dst_addr) = out22;
534 dst_addr += dst_plane_stride;
535 *(dst_addr) = out23;
536 dst_addr += dst_plane_stride;
537 *(dst_addr) = out24;
538 dst_addr += dst_plane_stride;
539 *(dst_addr) = out25;
540 dst_addr += dst_plane_stride;
541 *(dst_addr) = out26;
542 dst_addr += dst_plane_stride;
543 *(dst_addr) = out27;
544 dst_addr += dst_plane_stride;
545 *(dst_addr) = out28;
546 dst_addr += dst_plane_stride;
547 *(dst_addr) = out29;
548 dst_addr += dst_plane_stride;
549
550 // Row5
551 float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
552 float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
553
554 // Channels [30, 35]
555 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
556 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
557 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
558 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
559 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
560 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
561
562 *(dst_addr) = out0;
563 dst_addr += dst_plane_stride;
564 *(dst_addr) = out1;
565 dst_addr += dst_plane_stride;
566 *(dst_addr) = out2;
567 dst_addr += dst_plane_stride;
568 *(dst_addr) = out3;
569 dst_addr += dst_plane_stride;
570 *(dst_addr) = out4;
571 dst_addr += dst_plane_stride;
572 *(dst_addr) = out5;
573 dst_addr += dst_plane_stride;
574#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
575}
576
577#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100578/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100579 *
580 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
581 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
582 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
583 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100584 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
585 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
586 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
587 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100588 *
589 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
590 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
591 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
592 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
593 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
594 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
595 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
596 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
597 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
598 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
599 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
600 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
601 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
602 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
603 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
604 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
605 */
606__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
607 TENSOR3D_DECLARATION(src),
608 TENSOR3D_DECLARATION(dst))
609{
610 int x = get_global_id(0);
611 int y = get_global_id(1);
612 int z = get_global_id(2);
613
Giorgio Arena149fdf32018-07-04 17:03:33 +0100614 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100615
616 // Clamp coordinates. This clamp is valid for all rows
Giorgio Arena149fdf32018-07-04 17:03:33 +0100617 int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
618 int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100619 y_coord0 = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1);
620 y_coord1 = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100621
Giorgio Arena149fdf32018-07-04 17:03:33 +0100622 int z_coord;
623 int4 valid_y0;
624 int2 valid_y1;
625
626#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100627 // Row4
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100628 z_coord = (z * 4) - (int)PAD_TOP + 4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100629
630 // If z < 0, set y to -1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100631 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
632 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100633 // If z >= SRC_DIM_2, set y to SRC_DIM_2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100634 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
635 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100636
637 // Clamp z coordinate
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100638 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100639
640 float d40 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
641 float d41 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
642 float d42 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
643 float d43 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
644 float d44 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
645 float d45 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
646
647 float k0 = d44;
648 float k1 = d44;
649 float k2 = d44;
650 float k3 = d44;
651 float k4 = d44;
652 float k5 = (float)0.0f;
653
654 k0 += 4.0f * d40 - 5.0f * d42;
655 k1 += -4.0f * d41 - 4.0f * d42 + d43;
656 k2 += 4.0f * d41 - 4.0f * d42 - d43;
657 k3 += -2.0f * d41 + 2.0f * d43 - d42;
658 k4 += 2.0f * d41 - 2.0f * d43 - d42;
659 k5 += 4.0f * d41 - 5.0f * d43 + d45;
Giorgio Arena149fdf32018-07-04 17:03:33 +0100660#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100661
Giorgio Arena149fdf32018-07-04 17:03:33 +0100662#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100663 // Row0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100664 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100665
666#if PAD_TOP != 0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100667 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
668 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
669 valid_y0 = select(valid_y0, (int)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
670 valid_y1 = select(valid_y1, (int)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
671 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100672#else // PAD_TOP != 0
673 valid_y0 = y_coord0;
674 valid_y1 = y_coord1;
675#endif // if PAD_TOP == 0, we cannot read out of bound
676
677 float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
678 float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
679 float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
680 float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
681 float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
682 float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arena149fdf32018-07-04 17:03:33 +0100683#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
684 int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
685 int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100686
Giorgio Arena149fdf32018-07-04 17:03:33 +0100687 valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0);
688 valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0);
689 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2);
690 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2);
691
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100692 z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
693 z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
Giorgio Arena149fdf32018-07-04 17:03:33 +0100694
695 float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z);
696 float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z);
697 float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z);
698 float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z);
699 float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z);
700 float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z);
701#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
702
703 float out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
704 float out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04;
705 float out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04;
706 float out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04;
707 float out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04;
708 float out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;
709
710#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100711 // Row2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100712 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
713 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
714 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
715 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
716 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
717 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100718
719 float d20 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
720 float d21 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
721 float d22 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
722 float d23 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
723 float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
724 float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
725
Giorgio Arena149fdf32018-07-04 17:03:33 +0100726 out0 += k0;
727 out1 += k1;
728 out2 += k2;
729 out3 += k3;
730 out4 += k4;
731 out5 += k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100732 float out6 = k0;
733 float out7 = k1;
734 float out8 = k2;
735 float out9 = k3;
736 float out10 = k4;
737 float out11 = k5;
738 float out12 = k0;
739 float out13 = k1;
740 float out14 = k2;
741 float out15 = k3;
742 float out16 = k4;
743 float out17 = k5;
744 float out18 = k0;
745 float out19 = k1;
746 float out20 = k2;
747 float out21 = k3;
748 float out22 = k4;
749 float out23 = k5;
750 float out24 = k0;
751 float out25 = k1;
752 float out26 = k2;
753 float out27 = k3;
754 float out28 = k4;
755 float out29 = k5;
756
757 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
Giorgio Arena149fdf32018-07-04 17:03:33 +0100758 out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24;
759 out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24;
760 out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24;
761 out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24;
762 out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24;
763 out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25;
764#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
765
766 // Compute destination address
767 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
768 uint dst_plane_stride = dst_stride_z / sizeof(float);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100769
770 *((__global float *)dst_addr) = out0;
771 dst_addr += dst_plane_stride;
772 *((__global float *)dst_addr) = out1;
773 dst_addr += dst_plane_stride;
774 *((__global float *)dst_addr) = out2;
775 dst_addr += dst_plane_stride;
776 *((__global float *)dst_addr) = out3;
777 dst_addr += dst_plane_stride;
778 *((__global float *)dst_addr) = out4;
779 dst_addr += dst_plane_stride;
780 *((__global float *)dst_addr) = out5;
781 dst_addr += dst_plane_stride;
782
Giorgio Arena149fdf32018-07-04 17:03:33 +0100783#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100784 // Row1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100785 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100786 // Row1 can never be out of bounds
787 valid_y0 = y_coord0;
788 valid_y1 = y_coord1;
789
790 float d10 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
791 float d11 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
792 float d12 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
793 float d13 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
794 float d14 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
795 float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
796
797 // Row3
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100798 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
799 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
800 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
801 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
802 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
803 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
804 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100805
806 float d30 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
807 float d31 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
808 float d32 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
809 float d33 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
810 float d34 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
811 float d35 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
812
813 // Compute common parts for the channels between [6, 29]
814 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
815 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
816 float part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
817 float part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
818 float part2 = 16.0f * d22 - 4.0f * d24;
819 float part3 = 16.0f * d21 - 4.0f * d23;
820 float part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
821 float part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
822 float part6 = 4.0f * d22 - 4.0f * d24;
823 float part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
824 float part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
825 float part9 = 8.0f * d21 - 8.0f * d23;
826 float part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
827 float part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
828
829 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
830 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
831 float part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
832 float part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
833 float part14 = part2 * 0.25f; // 4.0f * d22 - d24
834 float part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
835 float part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
836 float part17 = part3 * 0.25f; // 4.0f * d21 - d23
837 float part18 = part6 * 0.25f; // d22 - d24
838 float part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
839 float part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
840 float part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
841 float part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
842 float part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
843
844 out6 += part0 - part1;
845 out12 += part0 + part1;
846 out7 += part2 + part3 + part4 + part5;
847 out8 += part2 - part3 + part4 - part5;
848 out13 += part2 + part3 - part4 - part5;
849 out14 += part2 - part3 - part4 + part5;
850 out9 += part6 + part7 + part8 + part9;
851 out10 += part6 - part7 + part8 - part9;
852 out15 += part6 - part7 - part8 + part9;
853 out16 += part6 + part7 - part8 - part9;
854 out11 += part10 + part11;
855 out17 += part10 - part11;
856
857 out18 += part13 - part12;
858 out24 += part13 + part12;
859 out19 += part14 + part15 + part16 + part17;
860 out20 += part14 - part15 + part16 - part17;
861 out25 += part14 - part15 - part16 + part17;
862 out26 += part14 + part15 - part16 - part17;
863 out21 += part18 + part19 + part20 + part21;
864 out22 += part18 - part19 + part20 - part21;
865 out27 += part18 - part19 - part20 + part21;
866 out28 += part18 + part19 - part20 - part21;
867 out23 += part22 + part23;
868 out29 += part22 - part23;
869
870 *((__global float *)dst_addr) = out6;
871 dst_addr += dst_plane_stride;
872 *((__global float *)dst_addr) = out7;
873 dst_addr += dst_plane_stride;
874 *((__global float *)dst_addr) = out8;
875 dst_addr += dst_plane_stride;
876 *((__global float *)dst_addr) = out9;
877 dst_addr += dst_plane_stride;
878 *((__global float *)dst_addr) = out10;
879 dst_addr += dst_plane_stride;
880 *((__global float *)dst_addr) = out11;
881 dst_addr += dst_plane_stride;
882 *((__global float *)dst_addr) = out12;
883 dst_addr += dst_plane_stride;
884 *((__global float *)dst_addr) = out13;
885 dst_addr += dst_plane_stride;
886 *((__global float *)dst_addr) = out14;
887 dst_addr += dst_plane_stride;
888 *((__global float *)dst_addr) = out15;
889 dst_addr += dst_plane_stride;
890 *((__global float *)dst_addr) = out16;
891 dst_addr += dst_plane_stride;
892 *((__global float *)dst_addr) = out17;
893 dst_addr += dst_plane_stride;
894
895 *((__global float *)dst_addr) = out18;
896 dst_addr += dst_plane_stride;
897 *((__global float *)dst_addr) = out19;
898 dst_addr += dst_plane_stride;
899 *((__global float *)dst_addr) = out20;
900 dst_addr += dst_plane_stride;
901 *((__global float *)dst_addr) = out21;
902 dst_addr += dst_plane_stride;
903 *((__global float *)dst_addr) = out22;
904 dst_addr += dst_plane_stride;
905 *((__global float *)dst_addr) = out23;
906 dst_addr += dst_plane_stride;
907 *((__global float *)dst_addr) = out24;
908 dst_addr += dst_plane_stride;
909 *((__global float *)dst_addr) = out25;
910 dst_addr += dst_plane_stride;
911 *((__global float *)dst_addr) = out26;
912 dst_addr += dst_plane_stride;
913 *((__global float *)dst_addr) = out27;
914 dst_addr += dst_plane_stride;
915 *((__global float *)dst_addr) = out28;
916 dst_addr += dst_plane_stride;
917 *((__global float *)dst_addr) = out29;
918 dst_addr += dst_plane_stride;
919
920 // Row5
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100921 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
922 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
923 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
924 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
925 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
926 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
927 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100928
929 float d50 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
930 float d51 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
931 float d52 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
932 float d53 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
933 float d54 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
934 float d55 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
935
936 // Channels [30, 35]
937 out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
938 out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
939 out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
940 out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
941 out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
942 out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
943
944 *((__global float *)dst_addr) = out0;
945 dst_addr += dst_plane_stride;
946 *((__global float *)dst_addr) = out1;
947 dst_addr += dst_plane_stride;
948 *((__global float *)dst_addr) = out2;
949 dst_addr += dst_plane_stride;
950 *((__global float *)dst_addr) = out3;
951 dst_addr += dst_plane_stride;
952 *((__global float *)dst_addr) = out4;
953 dst_addr += dst_plane_stride;
954 *((__global float *)dst_addr) = out5;
955 dst_addr += dst_plane_stride;
Giorgio Arena149fdf32018-07-04 17:03:33 +0100956#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100957}
958
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100959/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NHWC
Giorgio Arena149fdf32018-07-04 17:03:33 +0100960 *
961 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
962 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100965 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
Giorgio Arena149fdf32018-07-04 17:03:33 +0100966 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100967 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
968 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Giorgio Arena149fdf32018-07-04 17:03:33 +0100969 *
970 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
971 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
972 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
973 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
974 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
975 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
976 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
978 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
979 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
980 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
981 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
982 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
983 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
985 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
986 */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100987__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
Giorgio Arena149fdf32018-07-04 17:03:33 +0100988 TENSOR3D_DECLARATION(src),
989 TENSOR3D_DECLARATION(dst))
990{
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100991 int x = get_global_id(0);
992 int y = get_global_id(1);
993 int z = get_global_id(2);
994
995 // Compute input address
996 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);
997
998#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
999 // Clamp coordinates. This clamp is valid for all rows
1000 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001001 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001002
1003 // Row0
1004 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1005 int z_coord = z * OUTPUT_TILE_H;
1006
1007 // Load the input tile
1008 float8 in_row0;
1009 in_row0.s0 = *(__global float *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z);
1010 in_row0.s1 = *(__global float *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z);
1011 in_row0.s2 = *(__global float *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z);
1012 in_row0.s3 = *(__global float *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z);
1013 in_row0.s4 = *(__global float *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z);
1014 in_row0.s5 = *(__global float *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z);
1015 in_row0.s6 = *(__global float *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z);
1016 in_row0.s7 = *(__global float *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z);
1017
1018 // Calculate common factors for intermediate tensor
1019 float8 comm_fact0 = 0.0f;
1020 float8 tmp0 = in_row0;
1021
1022 float8 out0 = (float8)0.0f;
1023
1024 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1025
1026#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1027 // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001028 int y_coord = y * (int)OUTPUT_TILE_W;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001029
1030 // Row0
1031 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1032 int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001033 int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
1034 valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1035 z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001036
1037 // Load the input tile
1038 float8 in_row0;
1039 in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z);
1040 in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z);
1041 in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z);
1042 in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z);
1043 in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z);
1044 in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z);
1045 in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z);
1046 in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z);
1047
1048 // Calculate common factors for intermediate tensor
1049 float8 comm_fact0 = 0.0f;
1050 float8 tmp0 = in_row0;
1051
1052 float8 out0 = (float8)0.0f;
1053
1054 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1055#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1056 float8 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
1057
1058 // Clamp coordinates. This clamp is valid for all rows
1059 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001060 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001061
1062 // Row0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001063 int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
1064 int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
1065 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1066 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001067
1068 // Load the input tile
1069 in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1070 in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1071 in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1072 in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1073 in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1074 in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1075 in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1076 in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1077
1078 // Row1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001079 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
1080 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1081 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1082 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001083
1084 in_row1.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1085 in_row1.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1086 in_row1.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1087 in_row1.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1088 in_row1.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1089 in_row1.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1090 in_row1.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1091 in_row1.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1092
1093 // Row2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001094 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
1095 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1096 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1097 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001098
1099 in_row2.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1100 in_row2.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1101 in_row2.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1102 in_row2.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1103 in_row2.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1104 in_row2.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1105 in_row2.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1106 in_row2.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1107
1108 // Row3
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001109 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
1110 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1111 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1112 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001113
1114 in_row3.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1115 in_row3.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1116 in_row3.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1117 in_row3.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1118 in_row3.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1119 in_row3.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1120 in_row3.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1121 in_row3.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1122
1123 // Row4
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001124 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
1125 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1126 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1127 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001128
1129 in_row4.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1130 in_row4.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1131 in_row4.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1132 in_row4.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1133 in_row4.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1134 in_row4.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1135 in_row4.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1136 in_row4.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1137
1138 // Row5
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001139 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
1140 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1141 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1142 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001143
1144 in_row5.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1145 in_row5.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1146 in_row5.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1147 in_row5.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1148 in_row5.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1149 in_row5.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1150 in_row5.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1151 in_row5.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1152
1153 // Row6
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001154 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
1155 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1156 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1157 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001158
1159 in_row6.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1160 in_row6.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1161 in_row6.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1162 in_row6.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1163 in_row6.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1164 in_row6.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1165 in_row6.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1166 in_row6.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1167
1168 // Row7
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001169 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
1170 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1171 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1172 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001173
1174 in_row7.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1175 in_row7.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1176 in_row7.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1177 in_row7.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1178 in_row7.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1179 in_row7.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1180 in_row7.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1181 in_row7.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
1182
1183 float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
1184 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
1185 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
1186
1187 // Calculate intermediate tensor and reuse common factor vectors
1188 const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
1189 const float8 tmp1 = comm_fact0 + comm_fact1;
1190 const float8 tmp2 = comm_fact0 - comm_fact1;
1191
1192 comm_fact0 = 2.5f * in_row3;
1193 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
1194
1195 const float8 tmp3 = comm_fact1 + comm_fact2;
1196 const float8 tmp4 = comm_fact2 - comm_fact1;
1197
1198 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
1199 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
1200
1201 const float8 tmp5 = comm_fact1 + comm_fact2;
1202 const float8 tmp6 = comm_fact2 - comm_fact1;
1203 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
1204
1205 // Calculate output rows (reuse comm_fact0 vector)
1206 float8 out0, out1, out2, out3, out4, out5, out6, out7;
1207 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1208 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
1209 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
1210 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
1211 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
1212 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
1213 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
1214 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
1215#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1216
1217 // Store values across the channels
1218 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1219
1220 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1221 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1222 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1223 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1224 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1225 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1226 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1227 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1228
1229#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1230 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1231 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1232 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1233 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1234 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1235 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1236 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1237 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1238 *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1239 *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1240 *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1241 *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1242 *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1243 *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1244 *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1245 *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1246 *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1247 *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1248 *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1249 *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1250 *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1251 *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1252 *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1253 *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1254 *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1255 *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1256 *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1257 *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1258 *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1259 *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1260 *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1261 *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1262 *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1263 *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1264 *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1265 *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1266 *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1267 *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1268 *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1269 *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1270 *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1271 *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1272 *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1273 *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1274 *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1275 *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1276 *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1277 *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1278 *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1279 *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1280 *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1281 *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1282 *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1283 *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1284 *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1285 *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
1286#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001287}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001288#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
1289
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001290/** This OpenCL kernel computes the input transform when the kernel size is 5x5, 5x1 or 1x5 and the output tile is 4x4, 4x1 or 1x4 respectively, when the data layout is NCHW
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001291 *
1292 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
1293 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001294 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
1295 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
1296 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1297 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note Each work-item reads an 8x8 (or 8x1 / 1x8) block of floats and writes 64 (or 8) floats, one per dst Z slice.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001298 *
1299 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1300 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1301 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1302 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1303 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1304 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1305 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1306 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1307 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1308 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1309 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1310 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1311 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1312 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1313 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1314 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1315 */
1316__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
1317 TENSOR3D_DECLARATION(src),
1318 TENSOR3D_DECLARATION(dst))
1319{
    // Work-item coordinates: x and y pick the tile along the source X and Y axes
    // (in steps of OUTPUT_TILE_W / OUTPUT_TILE_H elements), z picks the source Z slice.
1320 int x = get_global_id(0);
1321 int y = get_global_id(1);
1322 int z = get_global_id(2);
1323
1324 // Compute input address of the tile's first element, before the padding shift
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001325 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001326
    // Move back by PAD_LEFT elements and PAD_TOP rows. No bounds check or clamping is done in this
    // kernel, so the source buffer presumably carries enough border for these reads - TODO confirm.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001327 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001328
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001329 // Load input tile
1330#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 5x1 case: a single row of 8 consecutive floats
1331 const float8 in_row0 = vload8(0, (__global float *)(src_addr));
1332#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1x5 case: a single column of 8 floats, src_stride_y bytes apart, packed into one vector
1333 const float8 in_row0 = (float8)(*((__global float *)(src_addr + 0 * src_stride_y)),
1334 *((__global float *)(src_addr + 1 * src_stride_y)),
1335 *((__global float *)(src_addr + 2 * src_stride_y)),
1336 *((__global float *)(src_addr + 3 * src_stride_y)),
1337 *((__global float *)(src_addr + 4 * src_stride_y)),
1338 *((__global float *)(src_addr + 5 * src_stride_y)),
1339 *((__global float *)(src_addr + 6 * src_stride_y)),
1340 *((__global float *)(src_addr + 7 * src_stride_y)));
1341#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // 5x5 case: 8 rows of 8 consecutive floats each
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001342 const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
1343 const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
1344 const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
1345 const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
1346 const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
1347 const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
1348 const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
1349 const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001350#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001351
1352 // Calculate common factors for intermediate tensor
    // In the 1D cases tmp0 stays equal to the loaded vector and comm_fact0 is only scratch space
    // for OUTPUT_ROW_4x4_5x5, which assigns its lanes before reading them.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001353 float8 tmp0 = in_row0;
1354 float8 comm_fact0 = 0.0f;
1355
1356#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // First pass (2D only): combine the 8 input rows into 8 intermediate rows tmp0..tmp7.
    // comm_fact0/1/2 are overwritten several times below, so statement order matters.
1357 comm_fact0 += in_row2 + in_row6 - 4.25f * in_row4;
1358 tmp0 += -in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
1359
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001360 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
1361 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
1362
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001363 const float8 tmp1 = comm_fact0 + comm_fact1;
1364 const float8 tmp2 = comm_fact0 - comm_fact1;
1365
1366 comm_fact0 = 2.5f * in_row3;
1367 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
1368
1369 const float8 tmp3 = comm_fact1 + comm_fact2;
1370 const float8 tmp4 = comm_fact2 - comm_fact1;
1371
1372 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
1373 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
1374
1375 const float8 tmp5 = comm_fact1 + comm_fact2;
1376 const float8 tmp6 = comm_fact2 - comm_fact1;
1377 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001378#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001379
1380 // Calculate output rows (reuse comm_fact0 vector)
    // Second pass: OUTPUT_ROW_4x4_5x5 (defined earlier in this file) fills outN from the lanes of tmpN.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001381 float8 out0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001382
1383 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001384
1385#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1386 float8 out1, out2, out3, out4, out5, out6, out7;
1387
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001388 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
1389 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
1390 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
1391 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
1392 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
1393 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
1394 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001395#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001396
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001397 // Store values across the channels
    // Destination addressing: z indexes floats within a dst row (byte offset z * sizeof(float)),
    // (x + y * NUM_TILES_X) is the linear tile index along dst Y, and each transformed value
    // is written to its own dst Z slice (k * dst_stride_z).
1398 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001399
    // First 8 values: written in every configuration (the only ones in the 1D cases)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001400 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1401 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1402 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1403 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1404 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1405 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1406 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1407 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1408
1409#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining 56 values of the 2D transform: row N of the result occupies Z slices 8*N .. 8*N+7
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001410 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1411 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1412 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1413 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1414 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1415 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1416 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1417 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1418 *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1419 *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1420 *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1421 *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1422 *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1423 *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1424 *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1425 *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1426 *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1427 *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1428 *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1429 *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1430 *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1431 *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1432 *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1433 *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1434 *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1435 *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1436 *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1437 *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1438 *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1439 *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1440 *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1441 *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1442 *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1443 *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1444 *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1445 *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1446 *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1447 *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1448 *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1449 *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1450 *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1451 *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1452 *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1453 *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1454 *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1455 *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1456 *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1457 *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1458 *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1459 *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1460 *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1461 *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1462 *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1463 *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1464 *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1465 *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001466#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001467}
1468
1469#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1470/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
1471 *
1472 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
1473 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
1474 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1475 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1476 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1477 *
1478 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1479 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1480 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1481 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1482 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1483 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1484 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1485 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1486 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1487 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1488 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1489 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1490 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1491 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1492 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1493 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1494 */
1495__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
1496 TENSOR3D_DECLARATION(src),
1497 TENSOR3D_DECLARATION(dst))
1498{
    // Pure forwarding wrapper: every tensor argument is passed through unchanged to the 2x2/3x3 kernel.
    // This wrapper is only compiled when WINOGRAD_INPUT_TRANSFORM_HORIZONTAL is defined (enclosing #if);
    // the callee is defined outside this excerpt and presumably specializes on that macro - TODO confirm.
    // The argument order (stride_z/step_z before the offset) presumably mirrors the TENSOR3D_DECLARATION
    // expansion from helpers.h; note that it differs from the order of the @param entries above.
1499 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
1500 src_stride_x,
1501 src_step_x,
1502 src_stride_y,
1503 src_step_y,
1504 src_stride_z,
1505 src_step_z,
1506 src_offset_first_element_in_bytes,
1507 dst_ptr,
1508 dst_stride_x,
1509 dst_step_x,
1510 dst_stride_y,
1511 dst_step_y,
1512 dst_stride_z,
1513 dst_step_z,
1514 dst_offset_first_element_in_bytes);
1515}
1516
1517/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is a multiple of 2
1518 *
1519 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
1520 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
1521 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1522 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1523 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1524 *
1525 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1526 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1527 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1528 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1529 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1530 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1531 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1532 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1533 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1534 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1535 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1536 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1537 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1538 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1539 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1540 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1541 */
1542__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
1543 TENSOR3D_DECLARATION(src),
1544 TENSOR3D_DECLARATION(dst))
1545{
    // Pure forwarding wrapper: every tensor argument is passed through unchanged to the 2x2/3x3 stepz2 kernel.
    // This wrapper is only compiled when WINOGRAD_INPUT_TRANSFORM_HORIZONTAL is defined (enclosing #if);
    // the callee is defined outside this excerpt and presumably specializes on that macro - TODO confirm.
    // The argument order (stride_z/step_z before the offset) presumably mirrors the TENSOR3D_DECLARATION
    // expansion from helpers.h; note that it differs from the order of the @param entries above.
1546 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
1547 src_stride_x,
1548 src_step_x,
1549 src_stride_y,
1550 src_step_y,
1551 src_stride_z,
1552 src_step_z,
1553 src_offset_first_element_in_bytes,
1554 dst_ptr,
1555 dst_stride_x,
1556 dst_step_x,
1557 dst_stride_y,
1558 dst_step_y,
1559 dst_stride_z,
1560 dst_step_z,
1561 dst_offset_first_element_in_bytes);
1562}
1563
/** Winograd input transform for a 3x1 kernel and a 4x1 output tile
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001610
/** Winograd input transform for a 5x1 kernel and a 4x1 output tile when the data layout is NCHW
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_5x5_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1657
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001658#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** Winograd input transform for a 3x1 kernel and a 4x1 output tile for data layout NHWC
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_3x3_stepz1_nhwc.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1707
/** Winograd input transform for a 5x1 kernel and a 4x1 output tile for data layout NHWC
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1756#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001757#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1758
1759#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
/** Winograd input transform for a 1x3 kernel and a 1x2 output tile
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_2x2_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1806
/** Winograd input transform for a 1x3 kernel and a 1x2 output tile, when the number of channels is a multiple of 2
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_2x2_3x3_stepz2_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
1853
/** Winograd input transform for a 1x3 kernel and a 1x4 output tile
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001900
/** Winograd input transform for a 1x5 kernel and a 1x4 output tile
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_5x5_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001947
1948#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** Winograd input transform for a 1x3 kernel and a 1x4 output tile for data layout NHWC
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_3x3_stepz1_nhwc.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001997
/** Winograd input transform for a 1x5 kernel and a 1x4 output tile for data layout NHWC
 *
 * Thin entry point: all fields of @p src and @p dst are forwarded unchanged to
 * winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5)
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1)
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Source fields first, destination fields second, each in declaration order.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
2046#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002047#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002048#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)