blob: 73da005996f2b7854d77a63778b643f9aac8dfe2 [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(SRC_DIM_Z)
27
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
 *
 * Each work-item transforms one filter plane g into G * g * G^T, with the transform matrix
 *
 *     G = | 1     0     0   |
 *         | 1/2   1/2   1/2 |
 *         | 1/2  -1/2   1/2 |
 *         | 0     0     1   |
 *
 * For the 3x3 filter all 16 values are produced. For the 3x1 / 1x3 variants only Row 0 is produced, i.e. the 4 values of G * g.
 * The results are written to consecutive Z planes of the destination, at X = filter index and Y = input channel.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // NOTE(review): presumably splits get_global_id(2) into Z/W using SRC_DIM_Z - see helpers.h
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: the three taps are contiguous along X
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three taps are one row apart along Y; gather them into a single vector
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 (for the 1D variants this is the whole transform G * g, since G[0][0] == 1)
    float4 out0 = 0.0f;
    out0.s0     = (w0.s0);
    out0.s1     = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2     = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3     = (w0.s2);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float4 out1 = 0.0f;
    out1.s0     = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3     = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2
    float4 out2 = 0.0f;
    out2.s0     = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1     = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2     = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3     = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3
    float4 out3 = 0.0f;
    out3.s0     = (w2.s0);
    out3.s1     = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2     = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // get_global_id(2) enumerates (filter, channel) pairs as filter * SRC_DIM_Z + channel
    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels
    // 16 channels for 3x3 kernels
    // 4 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
134
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
 *
 * Each work-item transforms one filter plane g into G * g * G^T, with the transform matrix
 *
 *     G = |  1/4    0      0   |
 *         | -1/6   -1/6   -1/6 |
 *         | -1/6    1/6   -1/6 |
 *         |  1/24   1/12   1/6 |
 *         |  1/24  -1/12   1/6 |
 *         |  0      0      1   |
 *
 * For the 3x3 filter all 36 values are produced. For the 3x1 / 1x3 variants only Row 0 is produced, i.e. G[0][0] * (G * g) = (G * g) / 4.
 * NOTE(review): the extra 1/4 factor in the 1D case is presumably compensated by the matching Winograd input/output transform - confirm before changing.
 * The results are written to consecutive Z planes of the destination, at X = filter index and Y = input channel.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // NOTE(review): presumably splits get_global_id(2) into Z/W using SRC_DIM_Z - see helpers.h
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: the three taps are contiguous along X
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three taps are one row apart along Y; gather them into a single vector
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 (only lanes s0..s5 are used; s6/s7 are padding of the float8)
    float8 out0 = 0.0f;
    out0.s0     = (w0.s0) / 16.f;
    out0.s1     = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2     = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3     = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4     = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5     = (w0.s2) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float8 out1 = 0.0f;
    out1.s0     = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5     = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    float8 out2 = 0.0f;
    out2.s0     = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1     = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2     = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5     = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    float8 out3 = 0.0f;
    out3.s0     = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5     = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    float8 out4 = 0.0f;
    out4.s0     = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5     = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    float8 out5 = 0.0f;
    out5.s0     = (w2.s0) / 4.f;
    out5.s1     = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2     = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3     = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4     = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // get_global_id(2) enumerates (filter, channel) pairs as filter * SRC_DIM_Z + channel
    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels
    // 36 channels for 3x3 kernels
    // 6 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
287
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NHWC and the output tile is 4x4/4x1/1x4
 *
 * Each work-item transforms one filter plane g into G * g * G^T, with the transform matrix
 *
 *     G = |  1/4    0      0   |
 *         | -1/6   -1/6   -1/6 |
 *         | -1/6    1/6   -1/6 |
 *         |  1/24   1/12   1/6 |
 *         |  1/24  -1/12   1/6 |
 *         |  0      0      1   |
 *
 * In NHWC the filter taps along the width are src_stride_y apart and the taps along the height are src_stride_z apart.
 * The 3x3 filter produces all 36 values. The 3x1 / 1x3 variants produce only Row 0, i.e. (G * g) / 4.
 * NOTE(review): the extra 1/4 factor in the 1D case is presumably compensated by the matching Winograd input/output
 * transform - confirm before changing.
 * The results are written to consecutive Z planes of the destination. X is get_global_id(2) (filter index) and
 * Y is get_global_id(0) (input channel).
 *
 * @note The source tensor's dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // The work-item's source address is built directly from the global IDs (no Tensor4D helper struct is needed):
    // dim 0 selects the channel, dim 2 selects the filter through the W stride.
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the values from the input tensor (wRC: R = row along Z/height, C = column along Y/width)
#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three taps lie along Z (height)
    float w00 = *((__global float *)(src_addr + 0 * src_stride_z));
    float w01 = *((__global float *)(src_addr + 1 * src_stride_z));
    float w02 = *((__global float *)(src_addr + 2 * src_stride_z));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // First filter row: the three taps lie along Y (width)
    float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // Remaining two rows of the 3x3 filter
    float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Row 0
    float out00, out01, out02, out03, out04, out05;
    out00 = (w00) / 16.f;
    out01 = (-w00 - w01 - w02) / 24.f;
    out02 = (-w00 + w01 - w02) / 24.f;
    out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    out05 = (w02) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float out10, out11, out12, out13, out14, out15;
    out10 = (-w00 - w10 - w20) / 24.f;
    out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2
    float out20, out21, out22, out23, out24, out25;
    out20 = (-w00 + w10 - w20) / 24.f;
    out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3
    float out30, out31, out32, out33, out34, out35;
    out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4
    float out40, out41, out42, out43, out44, out45;
    out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5
    float out50, out51, out52, out53, out54, out55;
    out50 = (w20) / 4.f;
    out51 = (-w20 - w21 - w22) / 6.f;
    out52 = (-w20 + w21 - w22) / 6.f;
    out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    out55 = (w22);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    int x0 = get_global_id(2); // idx filter
    int y0 = get_global_id(0); // idx channel

    // Get output address (elements along X are float-sized)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;

    // Store the values across the channels
    // 36 channels for 3x3 kernels
    // 6 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out00;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out01;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out02;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out03;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out04;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out05;
#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena149fdf32018-07-04 17:03:33 +0100444
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100445/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100446 *
447 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
448 *
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100449 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
450 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
451 *
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100452 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
453 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
454 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
455 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
456 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
457 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
458 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
459 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
460 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
461 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
462 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
463 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
464 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
465 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
466 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
467 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
468 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
469 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
470 */
__kernel void winograd_filter_transform_4x4_5x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Each work-item transforms one filter slice. Global id 2 is split into a filter index (x0)
    // and a channel index (y0) using SRC_DIM_Z. The transformed values are written one per
    // Z plane of dst: 64 planes for the 5x5 case, 8 planes for the 5x1 / 1x5 cases.
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor.
    // A filter row is kept as a float4 (columns 0..3) plus a float (column 4).
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 5x1: five contiguous values of row 0
    float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float  w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x5: one value from each of rows 0..4 (strided by src_stride_y)
    float4 w00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                          *((__global float *)(src_addr + 1 * src_stride_y)),
                          *((__global float *)(src_addr + 2 * src_stride_y)),
                          *((__global float *)(src_addr + 3 * src_stride_y)));
    float  w01 = *((__global float *)(src_addr + 4 * src_stride_y));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 5x5: rows 0..4, five contiguous values each
    float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float  w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
    float4 w10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float  w11 = *((__global float *)(src_addr + 1 * src_stride_y) + 4);
    float4 w20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float  w21 = *((__global float *)(src_addr + 2 * src_stride_y) + 4);
    float4 w30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
    float  w31 = *((__global float *)(src_addr + 3 * src_stride_y) + 4);
    float4 w40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
    float  w41 = *((__global float *)(src_addr + 4 * src_stride_y) + 4);
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Transform the filter: out = G * w * G^T (only G * w for the 5x1 / 1x5 cases).
    // Reading the coefficients used below, G is the 8x5 matrix
    //   [  1       0       0       0       0     ]
    //   [ -2/9    -2/9    -2/9    -2/9    -2/9   ]
    //   [ -2/9     2/9    -2/9     2/9    -2/9   ]
    //   [  1/90    2/90    4/90    8/90   16/90  ]
    //   [  1/90   -2/90    4/90   -8/90   16/90  ]
    //   [ 16/180   8/180   4/180   2/180   1/180 ]
    //   [ 16/180  -8/180   4/180  -2/180   1/180 ]
    //   [  0       0       0       0       1     ]

    // Row 0: G applied to filter row 0
    float8 out0 = 0.0f;
    out0.s0 = w00.s0;
    out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
    out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
    out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
    out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
    out0.s7 = w01;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Unscaled rows 1..6 of G * w: fixed linear combinations of the five filter rows, computed once.
    // Vector arithmetic is element-wise, so lane N of e.g. row1 is evaluated exactly like the scalar
    // w00.sN + w10.sN + w20.sN + w30.sN + w40.sN. The scale factor of each G row (-2/9, 1/90, 1/180)
    // is folded, together with the column factor, into the single final division of each output.
    const float4 row1   = w00 + w10 + w20 + w30 + w40;
    const float  row1_4 = w01 + w11 + w21 + w31 + w41;
    const float4 row2   = w00 - w10 + w20 - w30 + w40;
    const float  row2_4 = w01 - w11 + w21 - w31 + w41;
    const float4 row3   = w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40;
    const float  row3_4 = w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41;
    const float4 row4   = w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40;
    const float  row4_4 = w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41;
    const float4 row5   = 16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40;
    const float  row5_4 = 16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41;
    const float4 row6   = 16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40;
    const float  row6_4 = 16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41;

    // Row 1 (G row scale -2/9)
    float8 out1 = 0.0f;
    out1.s0 = -2.f * row1.s0 / 9.f;
    out1.s1 = 4.f * (row1.s0 + row1.s1 + row1.s2 + row1.s3 + row1_4) / 81.f;
    out1.s2 = 4.f * (row1.s0 - row1.s1 + row1.s2 - row1.s3 + row1_4) / 81.f;
    out1.s3 = -(row1.s0 + 2.f * row1.s1 + 4.f * row1.s2 + 8.f * row1.s3 + 16.f * row1_4) / 405.f;
    out1.s4 = -(row1.s0 - 2.f * row1.s1 + 4.f * row1.s2 - 8.f * row1.s3 + 16.f * row1_4) / 405.f;
    out1.s5 = -(16.f * row1.s0 + 8.f * row1.s1 + 4.f * row1.s2 + 2.f * row1.s3 + row1_4) / 810.f;
    out1.s6 = -(16.f * row1.s0 - 8.f * row1.s1 + 4.f * row1.s2 - 2.f * row1.s3 + row1_4) / 810.f;
    out1.s7 = -2.f * row1_4 / 9.f;

    // Row 2 (G row scale -2/9)
    float8 out2 = 0.0f;
    out2.s0 = -2.f * row2.s0 / 9.f;
    out2.s1 = 4.f * (row2.s0 + row2.s1 + row2.s2 + row2.s3 + row2_4) / 81.f;
    out2.s2 = 4.f * (row2.s0 - row2.s1 + row2.s2 - row2.s3 + row2_4) / 81.f;
    out2.s3 = -(row2.s0 + 2.f * row2.s1 + 4.f * row2.s2 + 8.f * row2.s3 + 16.f * row2_4) / 405.f;
    out2.s4 = -(row2.s0 - 2.f * row2.s1 + 4.f * row2.s2 - 8.f * row2.s3 + 16.f * row2_4) / 405.f;
    out2.s5 = -(16.f * row2.s0 + 8.f * row2.s1 + 4.f * row2.s2 + 2.f * row2.s3 + row2_4) / 810.f;
    out2.s6 = -(16.f * row2.s0 - 8.f * row2.s1 + 4.f * row2.s2 - 2.f * row2.s3 + row2_4) / 810.f;
    out2.s7 = -2.f * row2_4 / 9.f;

    // Row 3 (G row scale 1/90)
    float8 out3 = 0.0f;
    out3.s0 = row3.s0 / 90.f;
    out3.s1 = -(row3.s0 + row3.s1 + row3.s2 + row3.s3 + row3_4) / 405.f;
    out3.s2 = -(row3.s0 - row3.s1 + row3.s2 - row3.s3 + row3_4) / 405.f;
    out3.s3 = (row3.s0 + 2.f * row3.s1 + 4.f * row3.s2 + 8.f * row3.s3 + 16.f * row3_4) / 8100.f;
    out3.s4 = (row3.s0 - 2.f * row3.s1 + 4.f * row3.s2 - 8.f * row3.s3 + 16.f * row3_4) / 8100.f;
    out3.s5 = (16.f * row3.s0 + 8.f * row3.s1 + 4.f * row3.s2 + 2.f * row3.s3 + row3_4) / 16200.f;
    out3.s6 = (16.f * row3.s0 - 8.f * row3.s1 + 4.f * row3.s2 - 2.f * row3.s3 + row3_4) / 16200.f;
    out3.s7 = row3_4 / 90.f;

    // Row 4 (G row scale 1/90)
    float8 out4 = 0.0f;
    out4.s0 = row4.s0 / 90.f;
    out4.s1 = -(row4.s0 + row4.s1 + row4.s2 + row4.s3 + row4_4) / 405.f;
    out4.s2 = -(row4.s0 - row4.s1 + row4.s2 - row4.s3 + row4_4) / 405.f;
    out4.s3 = (row4.s0 + 2.f * row4.s1 + 4.f * row4.s2 + 8.f * row4.s3 + 16.f * row4_4) / 8100.f;
    out4.s4 = (row4.s0 - 2.f * row4.s1 + 4.f * row4.s2 - 8.f * row4.s3 + 16.f * row4_4) / 8100.f;
    out4.s5 = (16.f * row4.s0 + 8.f * row4.s1 + 4.f * row4.s2 + 2.f * row4.s3 + row4_4) / 16200.f;
    out4.s6 = (16.f * row4.s0 - 8.f * row4.s1 + 4.f * row4.s2 - 2.f * row4.s3 + row4_4) / 16200.f;
    out4.s7 = row4_4 / 90.f;

    // Row 5 (G row scale 1/180)
    float8 out5 = 0.0f;
    out5.s0 = row5.s0 / 180.f;
    out5.s1 = -(row5.s0 + row5.s1 + row5.s2 + row5.s3 + row5_4) / 810.f;
    out5.s2 = -(row5.s0 - row5.s1 + row5.s2 - row5.s3 + row5_4) / 810.f;
    out5.s3 = (row5.s0 + 2.f * row5.s1 + 4.f * row5.s2 + 8.f * row5.s3 + 16.f * row5_4) / 16200.f;
    out5.s4 = (row5.s0 - 2.f * row5.s1 + 4.f * row5.s2 - 8.f * row5.s3 + 16.f * row5_4) / 16200.f;
    out5.s5 = (16.f * row5.s0 + 8.f * row5.s1 + 4.f * row5.s2 + 2.f * row5.s3 + row5_4) / 32400.f;
    out5.s6 = (16.f * row5.s0 - 8.f * row5.s1 + 4.f * row5.s2 - 2.f * row5.s3 + row5_4) / 32400.f;
    out5.s7 = row5_4 / 180.f;

    // Row 6 (G row scale 1/180)
    float8 out6 = 0.0f;
    out6.s0 = row6.s0 / 180.f;
    out6.s1 = -(row6.s0 + row6.s1 + row6.s2 + row6.s3 + row6_4) / 810.f;
    out6.s2 = -(row6.s0 - row6.s1 + row6.s2 - row6.s3 + row6_4) / 810.f;
    out6.s3 = (row6.s0 + 2.f * row6.s1 + 4.f * row6.s2 + 8.f * row6.s3 + 16.f * row6_4) / 16200.f;
    out6.s4 = (row6.s0 - 2.f * row6.s1 + 4.f * row6.s2 - 8.f * row6.s3 + 16.f * row6_4) / 16200.f;
    out6.s5 = (16.f * row6.s0 + 8.f * row6.s1 + 4.f * row6.s2 + 2.f * row6.s3 + row6_4) / 32400.f;
    out6.s6 = (16.f * row6.s0 - 8.f * row6.s1 + 4.f * row6.s2 - 2.f * row6.s3 + row6_4) / 32400.f;
    out6.s7 = row6_4 / 180.f;

    // Row 7: G applied to filter row 4
    float8 out7 = 0.0f;
    out7.s0 = w40.s0;
    out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
    out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
    out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
    out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
    out7.s7 = w41;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // idx filter
    const int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address: filter index selects the X element, channel index the Y row
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;

    // Store the values across the channels: element k of the row-major 8x8 tile goes to Z plane k
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
    *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
    *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
    *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
    *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
    *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
    *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
    *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
    *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
    *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
    *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
    *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
    *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
    *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
    *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
    *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
    *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
    *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
    *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
    *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
    *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
    *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
730
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100731/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NHWC and the output tile is 4x4/4x1 or 1x4
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100732 *
733 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100734 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
735 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100736 *
737 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
738 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
739 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
740 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
741 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
742 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
743 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
744 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
745 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
746 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
747 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
748 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
749 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
750 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
751 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
754 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
755 */
756__kernel void winograd_filter_transform_4x4_5x5_nhwc(
757 TENSOR4D_DECLARATION(src),
758 TENSOR3D_DECLARATION(dst))
759{
760 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
761
762 const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
763
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100764#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100765 // Load the values from the input tensor
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100766 float w00 = *((__global float *)(src_addr + 0 * src_stride_z));
767 float w01 = *((__global float *)(src_addr + 1 * src_stride_z));
768 float w02 = *((__global float *)(src_addr + 2 * src_stride_z));
769 float w03 = *((__global float *)(src_addr + 3 * src_stride_z));
770 float w04 = *((__global float *)(src_addr + 4 * src_stride_z));
771#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
772 // Load the values from the input tensor
773 float w00 = *((__global float *)(src_addr + 0 * src_stride_y));
774 float w01 = *((__global float *)(src_addr + 1 * src_stride_y));
775 float w02 = *((__global float *)(src_addr + 2 * src_stride_y));
776 float w03 = *((__global float *)(src_addr + 3 * src_stride_y));
777 float w04 = *((__global float *)(src_addr + 4 * src_stride_y));
778#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
779
780#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100781 float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
782 float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
783 float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
784 float w13 = *((__global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
785 float w14 = *((__global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
786 float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
787 float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
788 float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
789 float w23 = *((__global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
790 float w24 = *((__global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
791 float w30 = *((__global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
792 float w31 = *((__global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
793 float w32 = *((__global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
794 float w33 = *((__global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
795 float w34 = *((__global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
796 float w40 = *((__global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
797 float w41 = *((__global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
798 float w42 = *((__global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
799 float w43 = *((__global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
800 float w44 = *((__global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100801#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100802
803 // Row 0
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100804 float8 out0 = 0.0f;
805 out0.s0 = w00;
806 out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
807 out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
808 out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
809 out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
810 out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
811 out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
812 out0.s7 = w04;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100813
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100814#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100815 // Row 1
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100816 float8 out1 = 0.0f;
817 out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
818 out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
819 out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
820 out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
821 (w04 + w14 + w24 + w34 + w44)) / 405.f;
822 out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
823 (w04 + w14 + w24 + w34 + w44)) / 405.f;
824 out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
825 (w04 + w14 + w24 + w34 + w44)) / 810.f;
826 out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
827 (w04 + w14 + w24 + w34 + w44)) / 810.f;
828 out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100829
830 // Row 2
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100831 float8 out2 = 0.0f;
832 out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
833 out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
834 out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
835 out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
836 (w04 - w14 + w24 - w34 + w44)) / 405.f;
837 out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
838 (w04 - w14 + w24 - w34 + w44)) / 405.f;
839 out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
840 (w04 - w14 + w24 - w34 + w44)) / 810.f;
841 out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
842 (w04 - w14 + w24 - w34 + w44)) / 810.f;
843 out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100844
845 // Row 3
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100846 float8 out3 = 0.0f;
847 out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
848 out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
849 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
850 out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
851 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
852 out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
853 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
854 out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
855 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
856 out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
857 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
858 out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
859 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
860 out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100861
862 // Row 4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100863 float8 out4 = 0.0f;
864 out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
865 out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
866 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
867 out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
868 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
869 out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
870 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
871 out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
872 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
873 out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
874 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
875 out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
876 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
877 out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100878
879 // Row 5
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100880 float8 out5 = 0.0f;
881 out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
882 out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
883 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
884 out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
885 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
886 out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
887 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
888 out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
889 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
890 out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
891 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
892 out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
893 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
894 out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100895
896 // Row 6
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100897 float8 out6 = 0.0f;
898 out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
899 out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
900 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
901 out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
902 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
903 out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
904 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
905 out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
906 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
907 out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
908 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
909 out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
910 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
911 out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100912
913 // Row 7
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100914 float8 out7 = 0.0f;
915 out7.s0 = w40;
916 out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
917 out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
918 out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
919 out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
920 out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
921 out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
922 out7.s7 = w44;
923#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100924
925 int x0 = get_global_id(2); // idx filter
926 int y0 = get_global_id(0); // idx channel
927
928 // Get output address
929 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
930
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100931 // Store the values across the channels
932 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
933 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
934 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
935 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
936 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
937 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
938 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
939 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
940
941#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100942 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
943 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
944 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
945 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
946 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
947 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
948 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
949 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
950 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
951 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
952 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
953 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
954 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
955 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
956 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
957 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
958 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
959 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
960 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
961 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
962 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
963 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
964 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
965 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
966 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
967 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
968 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
969 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
970 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
971 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
972 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
973 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
974 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
975 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
976 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
977 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
978 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
979 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
980 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
981 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
982 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
983 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
984 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
985 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
986 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
987 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
988 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
989 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
990 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
991 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
992 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
993 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
994 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
995 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
996 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
997 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100998#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100999}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001000#endif // defined(SRC_DIM_Z)
1001
1002#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined. That same define makes the callee perform the 3x1 transform.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D kernel. WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is guaranteed to be defined here
    // (this kernel sits inside that #if), so the callee performs the 1D 3x1 transform.
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1050
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined. The callee is expected to select its 1D (3x1) path from that same define.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D kernel. WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is guaranteed to be defined here
    // (this kernel sits inside that #if) and is what selects the 1D behaviour in the callee.
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1098
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_4x4_5x5_nchw. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined. Under that define the callee skips the extra rows of the 2D
 * transform and writes only 8 transformed values (along Z of the destination) instead of 64.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D kernel. WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is guaranteed to be defined here
    // (this kernel sits inside that #if), which restricts the callee to the single-row (1D) output.
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001146
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NHWC and the output tile is 4x1
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nhwc. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined. The callee is expected to select its 1D (3x1) path from that same define.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D NHWC kernel. WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is guaranteed to be defined here
    // (this kernel sits inside that #if) and is what selects the 1D behaviour in the callee.
    winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1194
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NHWC and the output tile is 4x1
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_4x4_5x5_nhwc. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined. The callee is expected to select its 1D (5x1) path from that same define.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D NHWC kernel. WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is guaranteed to be defined here
    // (this kernel sits inside that #if) and is what selects the 1D behaviour in the callee.
    winograd_filter_transform_4x4_5x5_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001242#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
1243
1244#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
 *
 * The kernel forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. It is only compiled when
 * -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined. That same define makes the callee perform the 1x3 transform.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 2D kernel. WINOGRAD_FILTER_TRANSFORM_VERTICAL is guaranteed to be defined here
    // (this kernel sits inside that #if), so the callee performs the 1D 1x3 transform.
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1292
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to select the vertical (1x3) Winograd filter transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward all arguments unchanged to the 4x4/3x3 NCHW filter transform; the 1x3
    // variant is selected at compile time by WINOGRAD_FILTER_TRANSFORM_VERTICAL.
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1340
/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to select the vertical (1x5) Winograd filter transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward all arguments unchanged to the 4x4/5x5 NCHW filter transform; the 1x5
    // variant is selected at compile time by WINOGRAD_FILTER_TRANSFORM_VERTICAL.
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001388
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NHWC and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to select the vertical (1x3) Winograd filter transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward all arguments unchanged to the 4x4/3x3 NHWC filter transform; the 1x3
    // variant is selected at compile time by WINOGRAD_FILTER_TRANSFORM_VERTICAL.
    winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1436
/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NHWC and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to select the vertical (1x5) Winograd filter transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward all arguments unchanged to the 4x4/5x5 NHWC filter transform; the 1x5
    // variant is selected at compile time by WINOGRAD_FILTER_TRANSFORM_VERTICAL.
    winograd_filter_transform_4x4_5x5_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Giorgio Arena149fdf32018-07-04 17:03:33 +01001484#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)