blob: 5f528d4b0e58106329b7b86e6f01a5b862b4b21e [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(SRC_DIM_Z)
27
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
 *
 * For a 3x3 filter g (rows w0, w1, w2) the kernel computes the 4x4 tile G * g * G^T, with
 *
 *         |  1     0     0  |
 *     G = | 1/2   1/2   1/2 |
 *         | 1/2  -1/2   1/2 |
 *         |  0     0     1  |
 *
 * For the 3x1 and 1x3 variants it computes the 4 values G * w0, where w0 holds the 3 filter weights.
 *
 * The transformed values are written to consecutive Z planes of the destination tensor:
 * - 16 planes for 3x3.
 * - 4 planes for 3x1/1x3.
 * They are stored at X = filter index and Y = channel index, both derived from get_global_id(2) and SRC_DIM_Z.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the filter weights
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3 weights read from consecutive X positions
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3 weights one Y stride apart: gather them into a single vector so the 1D transform below is shared
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: w<r> holds filter row r
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 of G * g * G^T; for the 1D variants this is the whole result G * w0
    float4 out0 = 0.0f;
    out0.s0 = (w0.s0);
    out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3 = (w0.s2);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1: 0.5 * (w0 + w1 + w2), transformed along the row
    float4 out1 = 0.0f;
    out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2: 0.5 * (w0 - w1 + w2), transformed along the row
    float4 out2 = 0.0f;
    out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3: w2, transformed along the row
    float4 out3 = 0.0f;
    out3.s0 = (w2.s0);
    out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3 = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Global Z enumerates (filter, channel) pairs: SRC_DIM_Z channels per filter
    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // filter index
    const int y0 = z % SRC_DIM_Z; // channel index

    // Destination element for this (filter, channel) pair in Z plane 0
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the Z planes in row-major tile order
    // 16 planes for 3x3 kernels
    // 4 planes for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
134
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
 *
 * For a 3x3 filter g (rows w0, w1, w2) the kernel computes the 6x6 tile G * g * G^T, with
 *
 *         |  1/4     0      0   |
 *         | -1/6   -1/6   -1/6  |
 *     G = | -1/6    1/6   -1/6  |
 *         |  1/24   1/12   1/6  |
 *         |  1/24  -1/12   1/6  |
 *         |   0      0      1   |
 *
 * For the 3x1 and 1x3 variants only row 0 of that product is computed, with w0 holding the 3 filter weights.
 * This gives 6 values equal to G * w0 scaled by G[0][0] = 1/4.
 * NOTE(review): presumably the matching 1D input/output transforms account for this scale — confirm.
 *
 * The transformed values are written to consecutive Z planes of the destination tensor:
 * - 36 planes for 3x3.
 * - 6 planes for 3x1/1x3.
 * They are stored at X = filter index and Y = channel index, both derived from get_global_id(2) and SRC_DIM_Z.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the filter weights
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3 weights read from consecutive X positions
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3 weights one Y stride apart: gather them into a single vector so the row-0 transform below is shared
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: w<r> holds filter row r
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 of G * g * G^T (the only row computed for the 1D variants).
    // Only lanes s0..s5 are used; s6/s7 stay zero.
    float8 out0 = 0.0f;
    out0.s0 = (w0.s0) / 16.f;
    out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5 = (w0.s2) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1: -1/6 * (w0 + w1 + w2), transformed along the row
    float8 out1 = 0.0f;
    out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2: -1/6 * (w0 - w1 + w2), transformed along the row
    float8 out2 = 0.0f;
    out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3: 1/24 * (w0 + 2 * w1 + 4 * w2), transformed along the row
    float8 out3 = 0.0f;
    out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4: 1/24 * (w0 - 2 * w1 + 4 * w2), transformed along the row
    float8 out4 = 0.0f;
    out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5: w2, transformed along the row
    float8 out5 = 0.0f;
    out5.s0 = (w2.s0) / 4.f;
    out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5 = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Global Z enumerates (filter, channel) pairs: SRC_DIM_Z channels per filter
    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // filter index
    const int y0 = z % SRC_DIM_Z; // channel index

    // Destination element for this (filter, channel) pair in Z plane 0
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the Z planes in row-major tile order
    // 36 planes for 3x3 kernels
    // 6 planes for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
287
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data layout is NHWC and the output tile is 4x4
 *
 * For a 3x3 filter g the kernel computes the 6x6 tile G * g * G^T, with
 *
 *         |  1/4     0      0   |
 *         | -1/6   -1/6   -1/6  |
 *     G = | -1/6    1/6   -1/6  |
 *         |  1/24   1/12   1/6  |
 *         |  1/24  -1/12   1/6  |
 *         |   0      0      1   |
 *
 * Weight w<r><c> is read at filter row r (r * src_stride_z) and filter column c (c * src_stride_y).
 * The 36 transformed values are written to Z planes 0..35 of the destination tensor.
 * They are stored at X = get_global_id(2) (filter index) and Y = get_global_id(0) (channel index).
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *       This kernel does not read SRC_DIM_Z itself. It is still required because the kernel lies inside an #if defined(SRC_DIM_Z) region.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Address of filter element (row 0, column 0) for this work-item, built directly from the work-item ids.
    // The previous Tensor4D struct was never used and has been removed.
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the 3x3 weights: w<r><c> = filter row r (along Z), filter column c (along Y)
    const float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    const float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));

    // Transform the 3x3 tile into a 6x6 tile (every output is assigned exactly once, so no zero-initialisation is needed)

    // Row 0: row w0*, transformed along the row and scaled by G[0][0] = 1/4
    const float out00 = (w00) / 16.f;
    const float out01 = (-w00 - w01 - w02) / 24.f;
    const float out02 = (-w00 + w01 - w02) / 24.f;
    const float out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    const float out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    const float out05 = (w02) / 4.f;

    // Row 1: -1/6 * (w0* + w1* + w2*), transformed along the row
    const float out10 = (-w00 - w10 - w20) / 24.f;
    const float out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    const float out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    const float out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2: -1/6 * (w0* - w1* + w2*), transformed along the row
    const float out20 = (-w00 + w10 - w20) / 24.f;
    const float out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    const float out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    const float out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3: 1/24 * (w0* + 2 * w1* + 4 * w2*), transformed along the row
    const float out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    const float out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4: 1/24 * (w0* - 2 * w1* + 4 * w2*), transformed along the row
    const float out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    const float out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5: row w2*, transformed along the row
    const float out50 = (w20) / 4.f;
    const float out51 = (-w20 - w21 - w22) / 6.f;
    const float out52 = (-w20 + w21 - w22) / 6.f;
    const float out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    const float out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    const float out55 = (w22);

    const int x0 = get_global_id(2); // filter index
    const int y0 = get_global_id(0); // channel index

    // Destination element for this (filter, channel) pair in Z plane 0
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 36 values across the Z planes in row-major tile order
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out00;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out01;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out02;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out03;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out04;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out05;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100437/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100438 *
439 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
440 *
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100441 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
442 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
443 *
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100444 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
445 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
446 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
447 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
448 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
449 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
450 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
451 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
452 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
453 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
454 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
455 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
456 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
457 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
458 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
461 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
462 */
463__kernel void winograd_filter_transform_4x4_5x5_nchw(
464 TENSOR4D_DECLARATION(src),
465 TENSOR3D_DECLARATION(dst))
466{
467 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
468
469 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
470
471 // Load the values from the input tensor
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100472#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
473 float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
474 float w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
475#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
476 float4 w00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
477 *((__global float *)(src_addr + 1 * src_stride_y)),
478 *((__global float *)(src_addr + 2 * src_stride_y)),
479 *((__global float *)(src_addr + 3 * src_stride_y)));
480 float w01 = *((__global float *)(src_addr + 4 * src_stride_y));
481#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
482 float4 w00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
483 float w01 = *((__global float *)(src_addr + 0 * src_stride_y) + 4);
484 float4 w10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
485 float w11 = *((__global float *)(src_addr + 1 * src_stride_y) + 4);
486 float4 w20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
487 float w21 = *((__global float *)(src_addr + 2 * src_stride_y) + 4);
488 float4 w30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
489 float w31 = *((__global float *)(src_addr + 3 * src_stride_y) + 4);
490 float4 w40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
491 float w41 = *((__global float *)(src_addr + 4 * src_stride_y) + 4);
492#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100493
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100494 // Transform the input tile
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100495
496 // Row 0
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100497 float8 out0 = 0.0f;
498 out0.s0 = w00.s0;
499 out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
500 out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
501 out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
502 out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
503 out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
504 out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
505 out0.s7 = w01;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100506
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100507#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100508 // Row 1
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100509 float8 out1 = 0.0f;
510 out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
511 out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
512 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
513 out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
514 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
515 out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
516 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
517 out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
518 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
519 out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
520 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
521 out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
522 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
523 out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100524
525 // Row 2
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100526 float8 out2 = 0.0f;
527 out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
528 out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
529 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
530 out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
531 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
532 out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
533 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
534 out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
535 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
536 out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
537 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
538 out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
539 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
540 out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100541
542 // Row 3
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100543 float8 out3 = 0.0f;
544 out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
545 out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
546 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
547 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
548 out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
549 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
550 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
551 out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
552 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
553 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
554 out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
555 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
556 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
557 out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
558 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
559 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
560 out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
561 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
562 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
563 out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100564
565 // Row 4
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100566 float8 out4 = 0.0f;
567 out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
568 out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
569 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
570 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
571 out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
572 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
573 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
574 out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
575 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
576 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
577 out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
578 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
579 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
580 out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
581 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
582 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
583 out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
584 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
585 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
586 out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100587
588 // Row 5
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100589 float8 out5 = 0.0f;
590 out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
591 out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
592 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
593 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
594 out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
595 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
596 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
597 out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
598 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
599 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
600 out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
601 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
602 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
603 out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
604 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
605 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
606 out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
607 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
608 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
609 out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100610
611 // Row 6
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100612 float8 out6 = 0.0f;
613 out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
614 out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
615 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
616 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
617 out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
618 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
619 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
620 out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
621 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
622 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
623 out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
624 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
625 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
626 out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
627 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
628 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
629 out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
630 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
631 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
632 out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100633
634 // Row 7
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100635 float8 out7 = 0.0f;
636 out7.s0 = w40.s0;
637 out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
638 out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
639 out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
640 out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
641 out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
642 out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
643 out7.s7 = w41;
644#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100645
646 int z = get_global_id(2);
647 int x0 = z / SRC_DIM_Z; // idx filter
648 int y0 = z % SRC_DIM_Z; // idx channel
649
650 // Get output address
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100651 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100652
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100653 // Store the values across the channels
654 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
655 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
656 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
657 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
658 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
659 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
660 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
661 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
662
663#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100664 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
665 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
666 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
667 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
668 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
669 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
670 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
671 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
672 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
673 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
674 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
675 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
676 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
677 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
678 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
679 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
680 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
681 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
682 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
683 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
684 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
685 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
686 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
687 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
688 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
689 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
690 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
691 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
692 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
693 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
694 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
695 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
696 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
697 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
698 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
699 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
700 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
701 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
702 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
703 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
704 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
705 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
706 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
707 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
708 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
709 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
710 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
711 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
712 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
713 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
714 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
715 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
716 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
717 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
718 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
719 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100720#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100721}
722
723/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NHWC and the output tile is 4x4
724 *
725 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
726 *
727 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
728 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
729 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
730 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
731 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
732 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
733 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
734 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
735 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
736 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
737 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
738 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
739 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
740 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
741 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
742 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
743 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
744 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
745 */
746__kernel void winograd_filter_transform_4x4_5x5_nhwc(
747 TENSOR4D_DECLARATION(src),
748 TENSOR3D_DECLARATION(dst))
749{
750 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
751
752 const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
753
754 // Load the values from the input tensor
755 float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
756 float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
757 float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
758 float w03 = *((__global float *)(src_addr + 0 * src_stride_z + 3 * src_stride_y));
759 float w04 = *((__global float *)(src_addr + 0 * src_stride_z + 4 * src_stride_y));
760 float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
761 float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
762 float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
763 float w13 = *((__global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
764 float w14 = *((__global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
765 float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
766 float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
767 float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
768 float w23 = *((__global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
769 float w24 = *((__global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
770 float w30 = *((__global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
771 float w31 = *((__global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
772 float w32 = *((__global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
773 float w33 = *((__global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
774 float w34 = *((__global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
775 float w40 = *((__global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
776 float w41 = *((__global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
777 float w42 = *((__global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
778 float w43 = *((__global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
779 float w44 = *((__global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
780
781 // Transform the 3x3 tile in a 8x8 tile
782 float8 out0 = 0.0f;
783 float8 out1 = 0.0f;
784 float8 out2 = 0.0f;
785 float8 out3 = 0.0f;
786 float8 out4 = 0.0f;
787 float8 out5 = 0.0f;
788 float8 out6 = 0.0f;
789 float8 out7 = 0.0f;
790
791 // Row 0
792 out0.s0 = w00;
793 out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
794 out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
795 out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
796 out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
797 out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
798 out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
799 out0.s7 = w04;
800
801 // Row 1
802 out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
803 out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
804 out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
805 out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
806 (w04 + w14 + w24 + w34 + w44)) / 405.f;
807 out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
808 (w04 + w14 + w24 + w34 + w44)) / 405.f;
809 out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
810 (w04 + w14 + w24 + w34 + w44)) / 810.f;
811 out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
812 (w04 + w14 + w24 + w34 + w44)) / 810.f;
813 out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;
814
815 // Row 2
816 out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
817 out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
818 out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
819 out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
820 (w04 - w14 + w24 - w34 + w44)) / 405.f;
821 out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
822 (w04 - w14 + w24 - w34 + w44)) / 405.f;
823 out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
824 (w04 - w14 + w24 - w34 + w44)) / 810.f;
825 out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
826 (w04 - w14 + w24 - w34 + w44)) / 810.f;
827 out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;
828
829 // Row 3
830 out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
831 out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
832 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
833 out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
834 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
835 out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f
836 * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
837 out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f
838 * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
839 out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
840 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
841 out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
842 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
843 out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;
844
845 // Row 4
846 out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
847 out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
848 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
849 out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
850 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
851 out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f
852 * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
853 out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f
854 * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
855 out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
856 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
857 out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
858 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
859 out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;
860
861 // Row 5
862 out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
863 out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
864 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
865 out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
866 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
867 out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f
868 * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
869 out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f
870 * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
871 out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
872 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
873 out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
874 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
875 out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;
876
877 // Row 6
878 out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
879 out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
880 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
881 out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
882 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
883 out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f
884 * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
885 out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f
886 * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
887 out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
888 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
889 out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
890 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
891 out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;
892
893 // Row 7
894 out7.s0 = w40;
895 out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
896 out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
897 out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
898 out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
899 out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
900 out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
901 out7.s7 = w44;
902
903 int x0 = get_global_id(2); // idx filter
904 int y0 = get_global_id(0); // idx channel
905
906 // Get output address
907 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
908
909 // Store the 64 values across the 64 channels
910 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
911 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
912 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
913 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
914 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
915 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
916 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
917 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
918 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
919 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
920 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
921 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
922 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
923 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
924 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
925 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
926 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
927 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
928 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
929 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
930 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
931 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
932 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
933 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
934 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
935 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
936 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
937 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
938 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
939 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
940 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
941 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
942 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
943 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
944 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
945 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
946 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
947 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
948 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
949 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
950 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
951 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
952 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
953 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
954 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
955 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
956 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
957 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
958 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
959 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
960 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
961 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
962 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
963 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
964 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
965 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
966 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
967 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
968 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
969 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
970 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
971 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
972 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
973 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
974}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100975#endif // defined(SRC_DIM_Z)
976
977#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_2x2_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 2x2/3x3 kernel. According to that kernel's own @note, it performs
    // the 3x1 transform when -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined, which is
    // always the case here because this wrapper is only compiled under that define.
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1025
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_4x4_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 4x4/3x3 kernel. This wrapper is only compiled when
    // -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined; presumably the callee uses that
    // define to select the 3x1 path (NOTE(review): confirm against its documentation).
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1073
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_4x4_5x5_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 4x4/5x5 kernel. This wrapper is only compiled when
    // -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined; presumably the callee uses that
    // define to select the 5x1 path (NOTE(review): confirm against its documentation).
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1121#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
1122
1123#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_2x2_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 2x2/3x3 kernel. According to that kernel's own @note, it performs
    // the 1x3 transform when -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined, which is
    // always the case here because this wrapper is only compiled under that define.
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1171
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_4x4_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 4x4/3x3 kernel. This wrapper is only compiled when
    // -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined; presumably the callee uses that
    // define to select the 1x3 path (NOTE(review): confirm against its documentation).
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1219
/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4
 *
 * This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_filter_transform_4x4_5x5_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Reuse the generic 4x4/5x5 kernel. This wrapper is only compiled when
    // -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined; presumably the callee uses that
    // define to select the 1x5 path (NOTE(review): confirm against its documentation).
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1267#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)