blob: 1ee6981a07a9e51adfd91b9cc6424cb5e62e50aa [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(SRC_DIM_Z)
27
28/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
29 *
30 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
31 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
32 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
33 *
34 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
35 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
36 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
37 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
38 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
39 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
40 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
41 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
42 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
43 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
44 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
45 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
46 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
47 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
48 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
49 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
50 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
51 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
52 */
53__kernel void winograd_filter_transform_2x2_3x3_nchw(
54 TENSOR4D_DECLARATION(src),
55 TENSOR3D_DECLARATION(dst))
56{
57 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
58
59 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
60
61 // Load the values from the input tensor
62#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
63 float3 w0 = vload3(0, (__global float *)(src_addr));
64#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
65 float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
66 *((__global float *)(src_addr + 1 * src_stride_y)),
67 *((__global float *)(src_addr + 2 * src_stride_y)));
68#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
69 float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
70 float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
71 float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
72#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
73
74 // Row 0
75 float4 out0 = 0.0f;
76 out0.s0 = (w0.s0);
77 out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
78 out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
79 out0.s3 = (w0.s2);
80
81#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
82 // Row 1
83 float4 out1 = 0.0f;
84 out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
85 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
86 out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
87 out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
88
89 // Row 2
90 float4 out2 = 0.0f;
91 out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
92 out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
93 out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
94 out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
95
96 // Row 3
97 float4 out3 = 0.0f;
98 out3.s0 = (w2.s0);
99 out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
100 out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
101 out3.s3 = (w2.s2);
102#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
103
104 int z = get_global_id(2);
105 int x0 = z / SRC_DIM_Z; // idx filter
106 int y0 = z % SRC_DIM_Z; // idx channel
107
108 // Get output address
109 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
110
111 // Store the values across the channels
112 // 16 channels for 3x3 kernels
113 // 4 channels for 3x1 or 1x3 kernels
114 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
115 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
116 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
117 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
118
119#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
120 *(__global float *)(dst_addr + 4 * dst_stride_z) = out1.s0;
121 *(__global float *)(dst_addr + 5 * dst_stride_z) = out1.s1;
122 *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s2;
123 *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s3;
124 *(__global float *)(dst_addr + 8 * dst_stride_z) = out2.s0;
125 *(__global float *)(dst_addr + 9 * dst_stride_z) = out2.s1;
126 *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
127 *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
128 *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
129 *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
130 *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
131 *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
132#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
133}
134
135/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
136 *
137 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
138 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
139 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
140 *
141 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
142 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
143 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
144 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
145 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
146 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
147 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
148 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
149 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
150 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
151 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
152 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
153 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
154 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
155 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
156 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
157 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
158 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
159 */
160__kernel void winograd_filter_transform_4x4_3x3_nchw(
161 TENSOR4D_DECLARATION(src),
162 TENSOR3D_DECLARATION(dst))
163{
164 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
165
166 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
167
168 // Load the values from the input tensor
169#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
170 float3 w0 = vload3(0, (__global float *)(src_addr));
171#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
172 float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
173 *((__global float *)(src_addr + 1 * src_stride_y)),
174 *((__global float *)(src_addr + 2 * src_stride_y)));
175#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
176 float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
177 float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
178 float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
179#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
180
181 // Row 0
182 float8 out0 = 0.0f;
183 out0.s0 = (w0.s0) / 16.f;
184 out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
185 out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
186 out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
187 out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
188 out0.s5 = (w0.s2) / 4.f;
189
190#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
191 // Row 1
192 float8 out1 = 0.0f;
193 out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
194 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
195 out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
196 out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
197 out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
198 out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
199
200 // Row 2
201 float8 out2 = 0.0f;
202 out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
203 out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
204 out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
205 out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
206 out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
207 out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
208
209 // Row 3
210 float8 out3 = 0.0f;
211 out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
212 out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
213 out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
214 out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
215 out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
216 out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
217
218 // Row 4
219 float8 out4 = 0.0f;
220 out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
221 out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
222 out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
223 out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
224 out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
225 out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
226
227 // Row 5
228 float8 out5 = 0.0f;
229 out5.s0 = (w2.s0) / 4.f;
230 out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
231 out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
232 out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
233 out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
234 out5.s5 = (w2.s2);
235#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
236
237 int z = get_global_id(2);
238 int x0 = z / SRC_DIM_Z; // idx filter
239 int y0 = z % SRC_DIM_Z; // idx channel
240
241 // Get output address
242 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
243
244 // Store the values across the channels
245 // 36 channels for 3x3 kernels
246 // 6 channels for 3x1 or 1x3 kernels
247 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
248 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
249 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
250 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
251 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
252 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
253
254#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
255 *(__global float *)(dst_addr + 6 * dst_stride_z) = out1.s0;
256 *(__global float *)(dst_addr + 7 * dst_stride_z) = out1.s1;
257 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s2;
258 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s3;
259 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
260 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
261 *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
262 *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
263 *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
264 *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
265 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
266 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
267 *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
268 *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
269 *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
270 *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
271 *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
272 *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
273 *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
274 *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
275 *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
276 *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
277 *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
278 *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
279 *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
280 *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
281 *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
282 *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
283 *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
284 *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
285#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
286}
287
288#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
289/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
290 *
291 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
292 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
293 *
294 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
295 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
296 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
297 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
298 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
299 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
300 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
301 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
302 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
303 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
304 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
305 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
306 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
307 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
308 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
309 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
310 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
311 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
312 */
313__kernel void winograd_filter_transform_2x1_3x1_nchw(
314 TENSOR4D_DECLARATION(src),
315 TENSOR3D_DECLARATION(dst))
316{
317 winograd_filter_transform_2x2_3x3_nchw(src_ptr,
318 src_stride_x,
319 src_step_x,
320 src_stride_y,
321 src_step_y,
322 src_stride_z,
323 src_step_z,
324 src_stride_w,
325 src_step_w,
326 src_offset_first_element_in_bytes,
327 dst_ptr,
328 dst_stride_x,
329 dst_step_x,
330 dst_stride_y,
331 dst_step_y,
332 dst_stride_z,
333 dst_step_z,
334 dst_offset_first_element_in_bytes);
335}
336
337/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
338 *
339 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
340 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
341 *
342 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
343 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
344 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
345 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
346 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
347 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
348 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
349 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
350 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
351 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
352 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
353 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
354 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
355 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
356 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
357 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
358 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
359 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
360 */
361__kernel void winograd_filter_transform_4x1_3x1_nchw(
362 TENSOR4D_DECLARATION(src),
363 TENSOR3D_DECLARATION(dst))
364{
365 winograd_filter_transform_4x4_3x3_nchw(src_ptr,
366 src_stride_x,
367 src_step_x,
368 src_stride_y,
369 src_step_y,
370 src_stride_z,
371 src_step_z,
372 src_stride_w,
373 src_step_w,
374 src_offset_first_element_in_bytes,
375 dst_ptr,
376 dst_stride_x,
377 dst_step_x,
378 dst_stride_y,
379 dst_step_y,
380 dst_stride_z,
381 dst_step_z,
382 dst_offset_first_element_in_bytes);
383}
384#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
385
386#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
387/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
388 *
389 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
390 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
391 *
392 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
393 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
394 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
395 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
396 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
397 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
398 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
399 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
400 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
401 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
402 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
403 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
404 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
405 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
406 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
407 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
408 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
409 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
410 */
411__kernel void winograd_filter_transform_1x2_1x3_nchw(
412 TENSOR4D_DECLARATION(src),
413 TENSOR3D_DECLARATION(dst))
414{
415 winograd_filter_transform_2x2_3x3_nchw(src_ptr,
416 src_stride_x,
417 src_step_x,
418 src_stride_y,
419 src_step_y,
420 src_stride_z,
421 src_step_z,
422 src_stride_w,
423 src_step_w,
424 src_offset_first_element_in_bytes,
425 dst_ptr,
426 dst_stride_x,
427 dst_step_x,
428 dst_stride_y,
429 dst_step_y,
430 dst_stride_z,
431 dst_step_z,
432 dst_offset_first_element_in_bytes);
433}
434
435/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
436 *
437 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
438 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
439 *
440 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
441 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
442 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
443 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
444 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
445 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
446 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
447 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
448 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
449 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
450 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
451 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
452 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
453 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
454 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
455 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
456 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
457 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
458 */
459__kernel void winograd_filter_transform_1x4_1x3_nchw(
460 TENSOR4D_DECLARATION(src),
461 TENSOR3D_DECLARATION(dst))
462{
463 winograd_filter_transform_4x4_3x3_nchw(src_ptr,
464 src_stride_x,
465 src_step_x,
466 src_stride_y,
467 src_step_y,
468 src_stride_z,
469 src_step_z,
470 src_stride_w,
471 src_step_w,
472 src_offset_first_element_in_bytes,
473 dst_ptr,
474 dst_stride_x,
475 dst_step_x,
476 dst_stride_y,
477 dst_step_y,
478 dst_stride_z,
479 dst_step_z,
480 dst_offset_first_element_in_bytes);
481}
482#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
483
484/** This OpenCL kernel performs Winograd filter transform 3x3 when the data layout is NHWC and the output tile is 4x4
485 *
486 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
487 *
488 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
489 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
490 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
491 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
492 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
493 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
494 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
495 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
496 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
497 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
498 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
499 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
500 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
501 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
502 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
505 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
506 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // The source is addressed directly from src_ptr (no Tensor4D struct is needed):
    //  - get_global_id(0) selects the channel (element offset along X)
    //  - get_global_id(1) is stepped with src_step_y
    //  - get_global_id(2) selects the filter (stepped with src_step_w)
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the 3x3 filter taps. In wRC, R is the offset along Z (kernel height for NHWC)
    // and C is the offset along Y.
    const float w00 = *((const __global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const float w01 = *((const __global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const float w02 = *((const __global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    const float w10 = *((const __global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const float w11 = *((const __global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const float w12 = *((const __global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const float w20 = *((const __global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const float w21 = *((const __global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const float w22 = *((const __global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));

    // Transform the 3x3 tile into a 6x6 tile: out = G * w * G^T with
    //     G = [  1/4     0     0 ]
    //         [ -1/6  -1/6  -1/6 ]
    //         [ -1/6   1/6  -1/6 ]
    //         [ 1/24  1/12   1/6 ]
    //         [ 1/24 -1/12   1/6 ]
    //         [    0     0     1 ]
    // Every output is assigned exactly once, so no zero-initialisation is needed.

    // Row 0
    const float out00 = (w00) / 16.f;
    const float out01 = (-w00 - w01 - w02) / 24.f;
    const float out02 = (-w00 + w01 - w02) / 24.f;
    const float out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    const float out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    const float out05 = (w02) / 4.f;

    // Row 1
    const float out10 = (-w00 - w10 - w20) / 24.f;
    const float out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    const float out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    const float out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2
    const float out20 = (-w00 + w10 - w20) / 24.f;
    const float out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    const float out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    const float out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3
    const float out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    const float out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4
    const float out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    const float out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5
    const float out50 = (w20) / 4.f;
    const float out51 = (-w20 - w21 - w22) / 6.f;
    const float out52 = (-w20 + w21 - w22) / 6.f;
    const float out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    const float out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    const float out55 = (w22);

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 36 values across the channels: outRC goes to Z offset (6 * R + C)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out00;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out01;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out02;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out03;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out04;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out05;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
}
/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NCHW and the output tile is 4x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *
 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_5x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the 5x5 filter. For row k (offset k * src_stride_y), wk0 holds columns 0-3 and wk1 holds column 4.
    // The row offsets are computed directly from the src_stride_y kernel argument. Packing them into an
    // 8-bit vector type would truncate the stride and wrap for row strides of 64 bytes or more
    // (e.g. a padded weights tensor), making the kernel read the wrong rows.
    float4 w00 = vload4(0, (const __global float *)(src_addr + 0 * src_stride_y));
    float  w01 = *((const __global float *)(src_addr + 0 * src_stride_y) + 4);
    float4 w10 = vload4(0, (const __global float *)(src_addr + 1 * src_stride_y));
    float  w11 = *((const __global float *)(src_addr + 1 * src_stride_y) + 4);
    float4 w20 = vload4(0, (const __global float *)(src_addr + 2 * src_stride_y));
    float  w21 = *((const __global float *)(src_addr + 2 * src_stride_y) + 4);
    float4 w30 = vload4(0, (const __global float *)(src_addr + 3 * src_stride_y));
    float  w31 = *((const __global float *)(src_addr + 3 * src_stride_y) + 4);
    float4 w40 = vload4(0, (const __global float *)(src_addr + 4 * src_stride_y));
    float  w41 = *((const __global float *)(src_addr + 4 * src_stride_y) + 4);

    // Transform the 5x5 tile into an 8x8 tile: out = G * w * G^T, where the rows of G are
    //     [1, 0, 0, 0, 0]
    //     -2/9 * [1,  1, 1,  1, 1]
    //     -2/9 * [1, -1, 1, -1, 1]
    //     [1,  2, 4,  8, 16] / 90
    //     [1, -2, 4, -8, 16] / 90
    //     [16,  8, 4,  2, 1] / 180
    //     [16, -8, 4, -2, 1] / 180
    //     [0, 0, 0, 0, 1]
    float8 out0 = 0.0f;
    float8 out1 = 0.0f;
    float8 out2 = 0.0f;
    float8 out3 = 0.0f;
    float8 out4 = 0.0f;
    float8 out5 = 0.0f;
    float8 out6 = 0.0f;
    float8 out7 = 0.0f;

    // Row 0
    out0.s0 = w00.s0;
    out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
    out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
    out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
    out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
    out0.s7 = w01;

    // Row 1
    out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
    out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
                     (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
    out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
                     (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
    out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
                (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
    out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
                (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
    out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
                (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
    out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
                (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
    out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;

    // Row 2
    out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
    out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
                     (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
    out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
                     (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
    out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
                (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
    out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
                (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
    out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
                (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
    out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
                (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
    out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;

    // Row 3
    out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
    out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
                (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
                (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
    out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
                (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
                (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
    out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
               (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
    out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
               (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
    out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
               (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
    out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
               (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
    out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;

    // Row 4
    out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
    out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
                (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
                (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
    out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
                (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
                (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
    out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
               (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
    out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
               (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
    out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
               (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
    out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
               (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
               (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
    out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;

    // Row 5
    out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
    out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
                (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
                (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
    out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
                (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
                (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
    out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
               (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
    out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
               (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
    out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
               (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
    out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
               (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
    out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;

    // Row 6
    out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
    out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
                (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
                (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
    out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
                (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
                (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
    out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
               (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
    out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
               (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
    out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
               (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
    out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
               (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
               (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
    out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;

    // Row 7
    out7.s0 = w40.s0;
    out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
    out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
    out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
    out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
    out7.s7 = w41;

    // get_global_id(2) enumerates (filter, channel) pairs, split using SRC_DIM_Z
    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 64 values across the 64 channels: outR.sC goes to Z offset (8 * R + C)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out0.s6;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out0.s7;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
    *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
    *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
    *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
    *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
    *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
    *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
    *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
    *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
    *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
    *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
    *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
    *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
    *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
    *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
    *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
    *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
    *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
    *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
    *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
    *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
}
902
903/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NHWC and the output tile is 4x4
904 *
905 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
906 *
907 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
908 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
909 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
910 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
911 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
912 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
913 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
914 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
915 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
916 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
917 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
918 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
919 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
920 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
921 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
924 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
925 */
__kernel void winograd_filter_transform_4x4_5x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Byte address of w00 for this work-item: work-item dim 0 selects the element along src's x axis,
    // dim 1 steps along y and dim 2 steps along w (src_step_w).
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the 5x5 tile: wKL is read K * src_stride_z + L * src_stride_y bytes away from src_addr
    const float w00 = *((const __global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const float w01 = *((const __global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const float w02 = *((const __global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    const float w03 = *((const __global float *)(src_addr + 0 * src_stride_z + 3 * src_stride_y));
    const float w04 = *((const __global float *)(src_addr + 0 * src_stride_z + 4 * src_stride_y));
    const float w10 = *((const __global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const float w11 = *((const __global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const float w12 = *((const __global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const float w13 = *((const __global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
    const float w14 = *((const __global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
    const float w20 = *((const __global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const float w21 = *((const __global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const float w22 = *((const __global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
    const float w23 = *((const __global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
    const float w24 = *((const __global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
    const float w30 = *((const __global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
    const float w31 = *((const __global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
    const float w32 = *((const __global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
    const float w33 = *((const __global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
    const float w34 = *((const __global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
    const float w40 = *((const __global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
    const float w41 = *((const __global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
    const float w42 = *((const __global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
    const float w43 = *((const __global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
    const float w44 = *((const __global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));

    // Transform the 5x5 tile w into an 8x8 tile: out = G * w * G^T, with the 8x5 matrix
    //
    //         [   1      0      0      0      0   ]
    //         [ -2/9   -2/9   -2/9   -2/9   -2/9  ]
    //         [ -2/9    2/9   -2/9    2/9   -2/9  ]
    //     G = [ 1/90   1/45   2/45   4/45   8/45  ]
    //         [ 1/90  -1/45   2/45  -4/45   8/45  ]
    //         [ 4/45   2/45   1/45   1/90  1/180  ]
    //         [ 4/45  -2/45   1/45  -1/90  1/180  ]
    //         [   0      0      0      0      1   ]
    //
    // The product is evaluated separably: first every row of w is multiplied by G (rK = row K of w * G^T),
    // then G is applied across r0..r4 with float8 arithmetic. This is mathematically identical to the
    // fully expanded form; individual results may differ from it only in the last bits (evaluation order).

    // Multiply the 5-vector (a0..a4) by G; lane j of the result is row j of G dotted with (a0..a4).
    // Defined locally (and #undef'd below) so no helper sits between the kernel and its doc comment.
#define WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(a0, a1, a2, a3, a4)                     \
    ((float8)((a0),                                                              \
              -2.f * ((a0) + (a1) + (a2) + (a3) + (a4)) / 9.f,                   \
              -2.f * ((a0) - (a1) + (a2) - (a3) + (a4)) / 9.f,                   \
              ((a0) + 2.f * (a1) + 4.f * (a2) + 8.f * (a3) + 16.f * (a4)) / 90.f,  \
              ((a0) - 2.f * (a1) + 4.f * (a2) - 8.f * (a3) + 16.f * (a4)) / 90.f,  \
              (16.f * (a0) + 8.f * (a1) + 4.f * (a2) + 2.f * (a3) + (a4)) / 180.f, \
              (16.f * (a0) - 8.f * (a1) + 4.f * (a2) - 2.f * (a3) + (a4)) / 180.f, \
              (a4)))

    // First pass: rows of w * G^T
    const float8 r0 = WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(w00, w01, w02, w03, w04);
    const float8 r1 = WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(w10, w11, w12, w13, w14);
    const float8 r2 = WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(w20, w21, w22, w23, w24);
    const float8 r3 = WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(w30, w31, w32, w33, w34);
    const float8 r4 = WINOGRAD_FILTER_ROW_4X4_5X5_NHWC(w40, w41, w42, w43, w44);

#undef WINOGRAD_FILTER_ROW_4X4_5X5_NHWC

    // Second pass: row i of the 8x8 result is row i of G applied across r0..r4 (same coefficients as above)
    const float8 out0 = r0;
    const float8 out1 = -2.f * (r0 + r1 + r2 + r3 + r4) / 9.f;
    const float8 out2 = -2.f * (r0 - r1 + r2 - r3 + r4) / 9.f;
    const float8 out3 = (r0 + 2.f * r1 + 4.f * r2 + 8.f * r3 + 16.f * r4) / 90.f;
    const float8 out4 = (r0 - 2.f * r1 + 4.f * r2 - 8.f * r3 + 16.f * r4) / 90.f;
    const float8 out5 = (16.f * r0 + 8.f * r1 + 4.f * r2 + 2.f * r3 + r4) / 180.f;
    const float8 out6 = (16.f * r0 - 8.f * r1 + 4.f * r2 - 2.f * r3 + r4) / 180.f;
    const float8 out7 = r4;

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;

    // Store the 64 values across the 64 z-planes: element (i, j) of the 8x8 tile goes to plane 8 * i + j
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out0.s6;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out0.s7;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
    *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
    *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
    *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
    *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
    *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
    *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
    *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
    *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
    *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
    *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
    *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
    *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
    *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
    *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
    *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
    *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
    *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
    *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
    *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
    *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
}
1155#endif // defined(SRC_DIM_Z)