/*
 * Copyright (c) 2018 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
25
26#if defined(SRC_DIM_Z)
27
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
 *
 * The transformed coefficients are written one per destination Z plane: 16 planes for a 3x3 filter,
 * 4 planes for a 3x1 or 1x3 filter. The global id along Z is split into a filter index (quotient by SRC_DIM_Z,
 * used along the destination X axis) and a channel index (remainder, used along the destination Y axis).
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src_tensor = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    // Byte address of the first filter coefficient read by this work-item
    const __global uchar *filter_addr = tensor4D_offset(&src_tensor, 0, 0, 0, 0);

    // Split the Z global id into filter index and channel index
    const int gid_z       = get_global_id(2);
    const int filter_idx  = gid_z / SRC_DIM_Z;
    const int channel_idx = gid_z % SRC_DIM_Z;

    // Destination byte address: filter index along X, channel index along Y
    __global uchar *out_addr = dst_ptr + dst_offset_first_element_in_bytes + filter_idx * dst_stride_x + channel_idx * dst_stride_y;

    // Fetch the filter coefficients
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: three consecutive elements
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = vload3(0, (__global DATA_TYPE *)(filter_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: three elements, each one Y stride apart
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(filter_addr + 0 * src_stride_y)),
                                                                         *((__global DATA_TYPE *)(filter_addr + 1 * src_stride_y)),
                                                                         *((__global DATA_TYPE *)(filter_addr + 2 * src_stride_y)));
#else  // 3x3 filter
    // Three rows of three consecutive elements, rows one Y stride apart
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = vload3(0, (__global DATA_TYPE *)(filter_addr + 0 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 3) w1 = vload3(0, (__global DATA_TYPE *)(filter_addr + 1 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 3) w2 = vload3(0, (__global DATA_TYPE *)(filter_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 (the only row produced for 3x1 / 1x3 filters)
    const DATA_TYPE out00 = (w0.s0);
    const DATA_TYPE out01 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    const DATA_TYPE out02 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    const DATA_TYPE out03 = (w0.s2);

#if !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
    // Row 1
    const DATA_TYPE out10 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    const DATA_TYPE out11 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    const DATA_TYPE out12 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    const DATA_TYPE out13 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2
    const DATA_TYPE out20 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    const DATA_TYPE out21 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    const DATA_TYPE out22 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    const DATA_TYPE out23 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3
    const DATA_TYPE out30 = (w2.s0);
    const DATA_TYPE out31 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    const DATA_TYPE out32 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    const DATA_TYPE out33 = (w2.s2);
#endif // !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))

    // Write one value per destination Z plane:
    // 16 planes for 3x3 kernels, 4 planes for 3x1 or 1x3 kernels
    *(__global DATA_TYPE *)(out_addr + 0 * dst_stride_z) = out00;
    *(__global DATA_TYPE *)(out_addr + 1 * dst_stride_z) = out01;
    *(__global DATA_TYPE *)(out_addr + 2 * dst_stride_z) = out02;
    *(__global DATA_TYPE *)(out_addr + 3 * dst_stride_z) = out03;

#if !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
    *(__global DATA_TYPE *)(out_addr + 4 * dst_stride_z)  = out10;
    *(__global DATA_TYPE *)(out_addr + 5 * dst_stride_z)  = out11;
    *(__global DATA_TYPE *)(out_addr + 6 * dst_stride_z)  = out12;
    *(__global DATA_TYPE *)(out_addr + 7 * dst_stride_z)  = out13;
    *(__global DATA_TYPE *)(out_addr + 8 * dst_stride_z)  = out20;
    *(__global DATA_TYPE *)(out_addr + 9 * dst_stride_z)  = out21;
    *(__global DATA_TYPE *)(out_addr + 10 * dst_stride_z) = out22;
    *(__global DATA_TYPE *)(out_addr + 11 * dst_stride_z) = out23;
    *(__global DATA_TYPE *)(out_addr + 12 * dst_stride_z) = out30;
    *(__global DATA_TYPE *)(out_addr + 13 * dst_stride_z) = out31;
    *(__global DATA_TYPE *)(out_addr + 14 * dst_stride_z) = out32;
    *(__global DATA_TYPE *)(out_addr + 15 * dst_stride_z) = out33;
#endif // !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
}
144
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
 *
 * The transformed coefficients are written one per destination Z plane: 36 planes for a 3x3 filter,
 * 6 planes for a 3x1 or 1x3 filter. The global id along Z is split into a filter index (quotient by SRC_DIM_Z,
 * used along the destination X axis) and a channel index (remainder, used along the destination Y axis).
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src_tensor = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    // Byte address of the first filter coefficient read by this work-item
    const __global uchar *filter_addr = tensor4D_offset(&src_tensor, 0, 0, 0, 0);

    // Split the Z global id into filter index and channel index
    const int gid_z       = get_global_id(2);
    const int filter_idx  = gid_z / SRC_DIM_Z;
    const int channel_idx = gid_z % SRC_DIM_Z;

    // Destination byte address: filter index along X, channel index along Y
    __global uchar *out_addr = dst_ptr + dst_offset_first_element_in_bytes + filter_idx * dst_stride_x + channel_idx * dst_stride_y;

    // Fetch the filter coefficients
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: three consecutive elements
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = vload3(0, (__global DATA_TYPE *)(filter_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: three elements, each one Y stride apart
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(filter_addr + 0 * src_stride_y)),
                                                                         *((__global DATA_TYPE *)(filter_addr + 1 * src_stride_y)),
                                                                         *((__global DATA_TYPE *)(filter_addr + 2 * src_stride_y)));
#else  // 3x3 filter
    // Three rows of three consecutive elements, rows one Y stride apart
    const VEC_DATA_TYPE(DATA_TYPE, 3) w0 = vload3(0, (__global DATA_TYPE *)(filter_addr + 0 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 3) w1 = vload3(0, (__global DATA_TYPE *)(filter_addr + 1 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 3) w2 = vload3(0, (__global DATA_TYPE *)(filter_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 (the only row produced for 3x1 / 1x3 filters)
    const DATA_TYPE out00 = (w0.s0) / 16.f;
    const DATA_TYPE out01 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    const DATA_TYPE out02 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    const DATA_TYPE out03 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    const DATA_TYPE out04 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    const DATA_TYPE out05 = (w0.s2) / 4.f;

#if !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
    // Row 1
    const DATA_TYPE out10 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    const DATA_TYPE out11 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    const DATA_TYPE out12 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    const DATA_TYPE out13 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    const DATA_TYPE out14 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    const DATA_TYPE out15 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    const DATA_TYPE out20 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    const DATA_TYPE out21 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    const DATA_TYPE out22 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    const DATA_TYPE out23 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    const DATA_TYPE out24 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    const DATA_TYPE out25 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    const DATA_TYPE out30 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    const DATA_TYPE out31 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    const DATA_TYPE out32 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    const DATA_TYPE out33 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    const DATA_TYPE out34 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    const DATA_TYPE out35 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    const DATA_TYPE out40 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    const DATA_TYPE out41 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    const DATA_TYPE out42 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    const DATA_TYPE out43 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    const DATA_TYPE out44 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    const DATA_TYPE out45 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    const DATA_TYPE out50 = (w2.s0) / 4.f;
    const DATA_TYPE out51 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    const DATA_TYPE out52 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    const DATA_TYPE out53 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    const DATA_TYPE out54 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    const DATA_TYPE out55 = (w2.s2);
#endif // !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))

    // Write one value per destination Z plane:
    // 36 planes for 3x3 kernels, 6 planes for 3x1 or 1x3 kernels
    *(__global DATA_TYPE *)(out_addr + 0 * dst_stride_z) = out00;
    *(__global DATA_TYPE *)(out_addr + 1 * dst_stride_z) = out01;
    *(__global DATA_TYPE *)(out_addr + 2 * dst_stride_z) = out02;
    *(__global DATA_TYPE *)(out_addr + 3 * dst_stride_z) = out03;
    *(__global DATA_TYPE *)(out_addr + 4 * dst_stride_z) = out04;
    *(__global DATA_TYPE *)(out_addr + 5 * dst_stride_z) = out05;

#if !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
    *(__global DATA_TYPE *)(out_addr + 6 * dst_stride_z)  = out10;
    *(__global DATA_TYPE *)(out_addr + 7 * dst_stride_z)  = out11;
    *(__global DATA_TYPE *)(out_addr + 8 * dst_stride_z)  = out12;
    *(__global DATA_TYPE *)(out_addr + 9 * dst_stride_z)  = out13;
    *(__global DATA_TYPE *)(out_addr + 10 * dst_stride_z) = out14;
    *(__global DATA_TYPE *)(out_addr + 11 * dst_stride_z) = out15;
    *(__global DATA_TYPE *)(out_addr + 12 * dst_stride_z) = out20;
    *(__global DATA_TYPE *)(out_addr + 13 * dst_stride_z) = out21;
    *(__global DATA_TYPE *)(out_addr + 14 * dst_stride_z) = out22;
    *(__global DATA_TYPE *)(out_addr + 15 * dst_stride_z) = out23;
    *(__global DATA_TYPE *)(out_addr + 16 * dst_stride_z) = out24;
    *(__global DATA_TYPE *)(out_addr + 17 * dst_stride_z) = out25;
    *(__global DATA_TYPE *)(out_addr + 18 * dst_stride_z) = out30;
    *(__global DATA_TYPE *)(out_addr + 19 * dst_stride_z) = out31;
    *(__global DATA_TYPE *)(out_addr + 20 * dst_stride_z) = out32;
    *(__global DATA_TYPE *)(out_addr + 21 * dst_stride_z) = out33;
    *(__global DATA_TYPE *)(out_addr + 22 * dst_stride_z) = out34;
    *(__global DATA_TYPE *)(out_addr + 23 * dst_stride_z) = out35;
    *(__global DATA_TYPE *)(out_addr + 24 * dst_stride_z) = out40;
    *(__global DATA_TYPE *)(out_addr + 25 * dst_stride_z) = out41;
    *(__global DATA_TYPE *)(out_addr + 26 * dst_stride_z) = out42;
    *(__global DATA_TYPE *)(out_addr + 27 * dst_stride_z) = out43;
    *(__global DATA_TYPE *)(out_addr + 28 * dst_stride_z) = out44;
    *(__global DATA_TYPE *)(out_addr + 29 * dst_stride_z) = out45;
    *(__global DATA_TYPE *)(out_addr + 30 * dst_stride_z) = out50;
    *(__global DATA_TYPE *)(out_addr + 31 * dst_stride_z) = out51;
    *(__global DATA_TYPE *)(out_addr + 32 * dst_stride_z) = out52;
    *(__global DATA_TYPE *)(out_addr + 33 * dst_stride_z) = out53;
    *(__global DATA_TYPE *)(out_addr + 34 * dst_stride_z) = out54;
    *(__global DATA_TYPE *)(out_addr + 35 * dst_stride_z) = out55;
#endif // !(defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL))
}
309
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NHWC and the output tile is 4x4/4x1/1x4
 *
 * The source address is computed directly from the global ids (X, Y and Z ids scaled by src_step_x, src_step_y and
 * src_step_w respectively). The transformed coefficients are written one per destination Z plane: 36 planes for a
 * 3x3 filter, 6 planes for a 3x1 or 1x3 filter. The Z global id selects the filter (destination X, advanced by one
 * element of DATA_TYPE) and the X global id selects the channel (destination Y).
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // The source address is derived from the raw kernel arguments, so no Tensor4D struct is needed here
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three coefficients are one Z stride apart
    const DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
    const DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
    const DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
#else  // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // First filter row: coefficients one Y stride apart
    const DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // Remaining two filter rows (3x3 only): rows are one Z stride apart
    const DATA_TYPE w10 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const DATA_TYPE w12 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const DATA_TYPE w20 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const DATA_TYPE w22 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Row 0 (the only row produced for 3x1 / 1x3 filters)
    const DATA_TYPE out00 = (w00) / 16.f;
    const DATA_TYPE out01 = (-w00 - w01 - w02) / 24.f;
    const DATA_TYPE out02 = (-w00 + w01 - w02) / 24.f;
    const DATA_TYPE out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    const DATA_TYPE out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    const DATA_TYPE out05 = (w02) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    const DATA_TYPE out10 = (-w00 - w10 - w20) / 24.f;
    const DATA_TYPE out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    const DATA_TYPE out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    const DATA_TYPE out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const DATA_TYPE out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const DATA_TYPE out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2
    const DATA_TYPE out20 = (-w00 + w10 - w20) / 24.f;
    const DATA_TYPE out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    const DATA_TYPE out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    const DATA_TYPE out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const DATA_TYPE out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const DATA_TYPE out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3
    const DATA_TYPE out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    const DATA_TYPE out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const DATA_TYPE out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const DATA_TYPE out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const DATA_TYPE out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const DATA_TYPE out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4
    const DATA_TYPE out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    const DATA_TYPE out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const DATA_TYPE out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const DATA_TYPE out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const DATA_TYPE out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const DATA_TYPE out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5
    const DATA_TYPE out50 = (w20) / 4.f;
    const DATA_TYPE out51 = (-w20 - w21 - w22) / 6.f;
    const DATA_TYPE out52 = (-w20 + w21 - w22) / 6.f;
    const DATA_TYPE out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    const DATA_TYPE out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    const DATA_TYPE out55 = (w22);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address: the filter index advances by one element along X, the channel index by one Y stride
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;

    // Store the values across the channels
    // 36 channels for 3x3 kernels
    // 6 channels for 3x1 or 1x3 kernels
    *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out00;
    *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out01;
    *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out02;
    *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out03;
    *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out04;
    *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out05;
#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out55;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena149fdf32018-07-04 17:03:33 +0100467
/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *
 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note Each work-item reads one 5x5 (or 5x1 / 1x5) block of weights and writes 64 (or 8) transformed values, one per Z plane of the destination tensor.
 *       The destination X coordinate is the filter index (get_global_id(2) / SRC_DIM_Z) and the destination Y coordinate is the channel index (get_global_id(2) % SRC_DIM_Z).
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_5x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor.
    // Every filter row is kept as a 4-wide vector (taps 0..3) plus one scalar (tap 4).
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 5x1 filter: the five taps are contiguous along X
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w00           = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x5 filter: the five taps are one src_stride_y apart; gather them into the same row layout
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w00           = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                                  *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                                  *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                                  *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
    DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
#else  // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 5x5 filter: five rows, one src_stride_y apart
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w00           = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w10           = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y) + 4);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w20           = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y) + 4);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w30           = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
    DATA_TYPE w31 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y) + 4);
    VEC_DATA_TYPE(DATA_TYPE, 4)
    w40           = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
    DATA_TYPE w41 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y) + 4);
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Transform the input tile.
    //
    // The eight lanes of every output row apply these weightings to five values (x0..x4):
    //   s0: x0
    //   s1: -2/9  * ( x0 +  x1 +  x2 +  x3 +   x4)
    //   s2: -2/9  * ( x0 -  x1 +  x2 -  x3 +   x4)
    //   s3: 1/90  * ( x0 + 2x1 + 4x2 + 8x3 + 16x4)
    //   s4: 1/90  * ( x0 - 2x1 + 4x2 - 8x3 + 16x4)
    //   s5: 1/180 * (16x0 + 8x1 + 4x2 + 2x3 +   x4)
    //   s6: 1/180 * (16x0 - 8x1 + 4x2 - 2x3 +   x4)
    //   s7: x4
    // For the 5x5 case the same eight weightings are also applied down the columns; the two
    // scale factors are folded into a single divisor per output element.

    // Row 0
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out0    = 0.0f;
    out0.s0 = w00.s0;
    out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
    out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
    out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
    out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
    out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
    out0.s7 = w01;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Unscaled column-wise combinations of the five filter rows, one set per interior output row.
    // tR_C is the combination feeding output row R taken over filter column C (C == 4 is the scalar tap).

    // Rows weighted (1, 1, 1, 1, 1)
    const DATA_TYPE t1_0 = w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0;
    const DATA_TYPE t1_1 = w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1;
    const DATA_TYPE t1_2 = w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2;
    const DATA_TYPE t1_3 = w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3;
    const DATA_TYPE t1_4 = w01 + w11 + w21 + w31 + w41;

    // Rows weighted (1, -1, 1, -1, 1)
    const DATA_TYPE t2_0 = w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0;
    const DATA_TYPE t2_1 = w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1;
    const DATA_TYPE t2_2 = w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2;
    const DATA_TYPE t2_3 = w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3;
    const DATA_TYPE t2_4 = w01 - w11 + w21 - w31 + w41;

    // The remaining combinations multiply by float literals, so the expressions are of type float
    // for both supported data types (float, and half promoted to float); they are therefore held
    // in float locals to avoid introducing an extra rounding step when DATA_TYPE is half.

    // Rows weighted (1, 2, 4, 8, 16)
    const float t3_0 = w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0;
    const float t3_1 = w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1;
    const float t3_2 = w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2;
    const float t3_3 = w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3;
    const float t3_4 = w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41;

    // Rows weighted (1, -2, 4, -8, 16)
    const float t4_0 = w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0;
    const float t4_1 = w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1;
    const float t4_2 = w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2;
    const float t4_3 = w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3;
    const float t4_4 = w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41;

    // Rows weighted (16, 8, 4, 2, 1)
    const float t5_0 = 16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0;
    const float t5_1 = 16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1;
    const float t5_2 = 16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2;
    const float t5_3 = 16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3;
    const float t5_4 = 16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41;

    // Rows weighted (16, -8, 4, -2, 1)
    const float t6_0 = 16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0;
    const float t6_1 = 16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1;
    const float t6_2 = 16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2;
    const float t6_3 = 16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3;
    const float t6_4 = 16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41;

    // Row 1 (column scale -2/9)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out1    = 0.0f;
    out1.s0 = -2.f * t1_0 / 9.f;
    out1.s1 = 4.f * (t1_0 + t1_1 + t1_2 + t1_3 + t1_4) / 81.f;
    out1.s2 = 4.f * (t1_0 - t1_1 + t1_2 - t1_3 + t1_4) / 81.f;
    out1.s3 = -(t1_0 + 2.f * t1_1 + 4.f * t1_2 + 8.f * t1_3 + 16.f * t1_4) / 405.f;
    out1.s4 = -(t1_0 - 2.f * t1_1 + 4.f * t1_2 - 8.f * t1_3 + 16.f * t1_4) / 405.f;
    out1.s5 = -(16.f * t1_0 + 8.f * t1_1 + 4.f * t1_2 + 2.f * t1_3 + t1_4) / 810.f;
    out1.s6 = -(16.f * t1_0 - 8.f * t1_1 + 4.f * t1_2 - 2.f * t1_3 + t1_4) / 810.f;
    out1.s7 = -2.f * t1_4 / 9.f;

    // Row 2 (column scale -2/9)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out2    = 0.0f;
    out2.s0 = -2.f * t2_0 / 9.f;
    out2.s1 = 4.f * (t2_0 + t2_1 + t2_2 + t2_3 + t2_4) / 81.f;
    out2.s2 = 4.f * (t2_0 - t2_1 + t2_2 - t2_3 + t2_4) / 81.f;
    out2.s3 = -(t2_0 + 2.f * t2_1 + 4.f * t2_2 + 8.f * t2_3 + 16.f * t2_4) / 405.f;
    out2.s4 = -(t2_0 - 2.f * t2_1 + 4.f * t2_2 - 8.f * t2_3 + 16.f * t2_4) / 405.f;
    out2.s5 = -(16.f * t2_0 + 8.f * t2_1 + 4.f * t2_2 + 2.f * t2_3 + t2_4) / 810.f;
    out2.s6 = -(16.f * t2_0 - 8.f * t2_1 + 4.f * t2_2 - 2.f * t2_3 + t2_4) / 810.f;
    out2.s7 = -2.f * t2_4 / 9.f;

    // Row 3 (column scale 1/90)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out3    = 0.0f;
    out3.s0 = t3_0 / 90.f;
    out3.s1 = -(t3_0 + t3_1 + t3_2 + t3_3 + t3_4) / 405.f;
    out3.s2 = -(t3_0 - t3_1 + t3_2 - t3_3 + t3_4) / 405.f;
    out3.s3 = (t3_0 + 2.f * t3_1 + 4.f * t3_2 + 8.f * t3_3 + 16.f * t3_4) / 8100.f;
    out3.s4 = (t3_0 - 2.f * t3_1 + 4.f * t3_2 - 8.f * t3_3 + 16.f * t3_4) / 8100.f;
    out3.s5 = (16.f * t3_0 + 8.f * t3_1 + 4.f * t3_2 + 2.f * t3_3 + t3_4) / 16200.f;
    out3.s6 = (16.f * t3_0 - 8.f * t3_1 + 4.f * t3_2 - 2.f * t3_3 + t3_4) / 16200.f;
    out3.s7 = t3_4 / 90.f;

    // Row 4 (column scale 1/90)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out4    = 0.0f;
    out4.s0 = t4_0 / 90.f;
    out4.s1 = -(t4_0 + t4_1 + t4_2 + t4_3 + t4_4) / 405.f;
    out4.s2 = -(t4_0 - t4_1 + t4_2 - t4_3 + t4_4) / 405.f;
    out4.s3 = (t4_0 + 2.f * t4_1 + 4.f * t4_2 + 8.f * t4_3 + 16.f * t4_4) / 8100.f;
    out4.s4 = (t4_0 - 2.f * t4_1 + 4.f * t4_2 - 8.f * t4_3 + 16.f * t4_4) / 8100.f;
    out4.s5 = (16.f * t4_0 + 8.f * t4_1 + 4.f * t4_2 + 2.f * t4_3 + t4_4) / 16200.f;
    out4.s6 = (16.f * t4_0 - 8.f * t4_1 + 4.f * t4_2 - 2.f * t4_3 + t4_4) / 16200.f;
    out4.s7 = t4_4 / 90.f;

    // Row 5 (column scale 1/180)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out5    = 0.0f;
    out5.s0 = t5_0 / 180.f;
    out5.s1 = -(t5_0 + t5_1 + t5_2 + t5_3 + t5_4) / 810.f;
    out5.s2 = -(t5_0 - t5_1 + t5_2 - t5_3 + t5_4) / 810.f;
    out5.s3 = (t5_0 + 2.f * t5_1 + 4.f * t5_2 + 8.f * t5_3 + 16.f * t5_4) / 16200.f;
    out5.s4 = (t5_0 - 2.f * t5_1 + 4.f * t5_2 - 8.f * t5_3 + 16.f * t5_4) / 16200.f;
    out5.s5 = (16.f * t5_0 + 8.f * t5_1 + 4.f * t5_2 + 2.f * t5_3 + t5_4) / 32400.f;
    out5.s6 = (16.f * t5_0 - 8.f * t5_1 + 4.f * t5_2 - 2.f * t5_3 + t5_4) / 32400.f;
    out5.s7 = t5_4 / 180.f;

    // Row 6 (column scale 1/180)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out6    = 0.0f;
    out6.s0 = t6_0 / 180.f;
    out6.s1 = -(t6_0 + t6_1 + t6_2 + t6_3 + t6_4) / 810.f;
    out6.s2 = -(t6_0 - t6_1 + t6_2 - t6_3 + t6_4) / 810.f;
    out6.s3 = (t6_0 + 2.f * t6_1 + 4.f * t6_2 + 8.f * t6_3 + 16.f * t6_4) / 16200.f;
    out6.s4 = (t6_0 - 2.f * t6_1 + 4.f * t6_2 - 8.f * t6_3 + 16.f * t6_4) / 16200.f;
    out6.s5 = (16.f * t6_0 + 8.f * t6_1 + 4.f * t6_2 + 2.f * t6_3 + t6_4) / 32400.f;
    out6.s6 = (16.f * t6_0 - 8.f * t6_1 + 4.f * t6_2 - 2.f * t6_3 + t6_4) / 32400.f;
    out6.s7 = t6_4 / 180.f;

    // Row 7 (last filter row, no column scaling)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out7    = 0.0f;
    out7.s0 = w40.s0;
    out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
    out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
    out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
    out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
    out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
    out7.s7 = w41;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;

    // Store the values across the channels
    // 64 channels for 5x5 kernels
    // 8 channels for 5x1 or 1x5 kernels
    *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
    *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out0.s6;
    *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out0.s7;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)  = out1.s0;
    *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)  = out1.s1;
    *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s2;
    *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s3;
    *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out1.s4;
    *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out1.s5;
    *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out1.s6;
    *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out1.s7;
    *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s0;
    *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s1;
    *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out2.s2;
    *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out2.s3;
    *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out2.s4;
    *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out2.s5;
    *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out2.s6;
    *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out2.s7;
    *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out3.s0;
    *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out3.s1;
    *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out3.s2;
    *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out3.s3;
    *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out3.s4;
    *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out3.s5;
    *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out3.s6;
    *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out3.s7;
    *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out4.s0;
    *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out4.s1;
    *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out4.s2;
    *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out4.s3;
    *(__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z) = out4.s4;
    *(__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z) = out4.s5;
    *(__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z) = out4.s6;
    *(__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z) = out4.s7;
    *(__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z) = out5.s0;
    *(__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z) = out5.s1;
    *(__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z) = out5.s2;
    *(__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z) = out5.s3;
    *(__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z) = out5.s4;
    *(__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z) = out5.s5;
    *(__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z) = out5.s6;
    *(__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z) = out5.s7;
    *(__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z) = out6.s0;
    *(__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z) = out6.s1;
    *(__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z) = out6.s2;
    *(__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z) = out6.s3;
    *(__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z) = out6.s4;
    *(__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z) = out6.s5;
    *(__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z) = out6.s6;
    *(__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z) = out6.s7;
    *(__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z) = out7.s0;
    *(__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z) = out7.s1;
    *(__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z) = out7.s2;
    *(__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z) = out7.s3;
    *(__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z) = out7.s4;
    *(__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z) = out7.s5;
    *(__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z) = out7.s6;
    *(__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z) = out7.s7;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
769
/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NHWC and the output tile is 4x4/4x1 or 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_5x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // NOTE(review): `src` is never read again in this kernel; every address below is built from the raw
    // kernel arguments. It is kept because the helper macro is defined outside this file section -
    // confirm it has no side effects before removing it.
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    // Base address of the filter handled by this work-item:
    // global id 0 selects the channel (one element along X), global id 2 selects the filter (W)
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(DATA_TYPE) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

// Read the filter tap that sits z strides along Z and y strides along Y from src_addr
#define WINO_LOAD(z, y) (*((__global DATA_TYPE *)(src_addr + (z) * src_stride_z + (y) * src_stride_y)))

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1D vertical filter: the five taps are consecutive along Z
    DATA_TYPE w00 = WINO_LOAD(0, 0);
    DATA_TYPE w01 = WINO_LOAD(1, 0);
    DATA_TYPE w02 = WINO_LOAD(2, 0);
    DATA_TYPE w03 = WINO_LOAD(3, 0);
    DATA_TYPE w04 = WINO_LOAD(4, 0);
#else  // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // First filter row (the only one for the 1D horizontal case): the five taps are consecutive along Y
    DATA_TYPE w00 = WINO_LOAD(0, 0);
    DATA_TYPE w01 = WINO_LOAD(0, 1);
    DATA_TYPE w02 = WINO_LOAD(0, 2);
    DATA_TYPE w03 = WINO_LOAD(0, 3);
    DATA_TYPE w04 = WINO_LOAD(0, 4);
#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Remaining four rows of the 5x5 filter: w<i><j> is the tap i strides along Z and j strides along Y
    DATA_TYPE w10 = WINO_LOAD(1, 0);
    DATA_TYPE w11 = WINO_LOAD(1, 1);
    DATA_TYPE w12 = WINO_LOAD(1, 2);
    DATA_TYPE w13 = WINO_LOAD(1, 3);
    DATA_TYPE w14 = WINO_LOAD(1, 4);
    DATA_TYPE w20 = WINO_LOAD(2, 0);
    DATA_TYPE w21 = WINO_LOAD(2, 1);
    DATA_TYPE w22 = WINO_LOAD(2, 2);
    DATA_TYPE w23 = WINO_LOAD(2, 3);
    DATA_TYPE w24 = WINO_LOAD(2, 4);
    DATA_TYPE w30 = WINO_LOAD(3, 0);
    DATA_TYPE w31 = WINO_LOAD(3, 1);
    DATA_TYPE w32 = WINO_LOAD(3, 2);
    DATA_TYPE w33 = WINO_LOAD(3, 3);
    DATA_TYPE w34 = WINO_LOAD(3, 4);
    DATA_TYPE w40 = WINO_LOAD(4, 0);
    DATA_TYPE w41 = WINO_LOAD(4, 1);
    DATA_TYPE w42 = WINO_LOAD(4, 2);
    DATA_TYPE w43 = WINO_LOAD(4, 3);
    DATA_TYPE w44 = WINO_LOAD(4, 4);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // The 8x8 output is G * w * G^T, where the rows of the 8x5 matrix G are (scale factor left out):
    //   r0:          [ 1   0   0   0   0 ]
    //   r1: -2/9   * [ 1   1   1   1   1 ]
    //   r2: -2/9   * [ 1  -1   1  -1   1 ]
    //   r3:  1/90  * [ 1   2   4   8  16 ]
    //   r4:  1/90  * [ 1  -2   4  -8  16 ]
    //   r5:  1/180 * [16   8   4   2   1 ]
    //   r6:  1/180 * [16  -8   4  -2   1 ]
    //   r7:          [ 0   0   0   0   1 ]
    // The helpers below are macros on purpose: they expand to exactly the expressions that used to be
    // spelled out by hand, so operand promotion and rounding are the same for every DATA_TYPE.

// WINO_V<r>(c): unscaled row r of G applied down column c of the filter (taps w0c..w4c)
#define WINO_V0(c) w0##c
#define WINO_V1(c) (w0##c + w1##c + w2##c + w3##c + w4##c)
#define WINO_V2(c) (w0##c - w1##c + w2##c - w3##c + w4##c)
#define WINO_V3(c) (w0##c + 2.f * w1##c + 4.f * w2##c + 8.f * w3##c + 16.f * w4##c)
#define WINO_V4(c) (w0##c - 2.f * w1##c + 4.f * w2##c - 8.f * w3##c + 16.f * w4##c)
#define WINO_V5(c) (16.f * w0##c + 8.f * w1##c + 4.f * w2##c + 2.f * w3##c + w4##c)
#define WINO_V6(c) (16.f * w0##c - 8.f * w1##c + 4.f * w2##c - 2.f * w3##c + w4##c)
#define WINO_V7(c) w4##c

// WINO_H<k>(V): unscaled row k of G applied across the five column values V(0)..V(4)
#define WINO_H1(V) (V(0) + V(1) + V(2) + V(3) + V(4))
#define WINO_H2(V) (V(0) - V(1) + V(2) - V(3) + V(4))
#define WINO_H3(V) (V(0) + 2.f * V(1) + 4.f * V(2) + 8.f * V(3) + 16.f * V(4))
#define WINO_H4(V) (V(0) - 2.f * V(1) + 4.f * V(2) - 8.f * V(3) + 16.f * V(4))
#define WINO_H5(V) (16.f * V(0) + 8.f * V(1) + 4.f * V(2) + 2.f * V(3) + V(4))
#define WINO_H6(V) (16.f * V(0) - 8.f * V(1) + 4.f * V(2) - 2.f * V(3) + V(4))

    // Row 0 (the only row produced by the 1D horizontal / vertical variants)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out0    = 0.0f;
    out0.s0 = WINO_V0(0);
    out0.s1 = -2.f * WINO_H1(WINO_V0) / 9.f;
    out0.s2 = -2.f * WINO_H2(WINO_V0) / 9.f;
    out0.s3 = WINO_H3(WINO_V0) / 90.f;
    out0.s4 = WINO_H4(WINO_V0) / 90.f;
    out0.s5 = WINO_H5(WINO_V0) / 180.f;
    out0.s6 = WINO_H6(WINO_V0) / 180.f;
    out0.s7 = WINO_V0(4);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1: scale factors are (-2/9) * (scale of the horizontal row)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out1    = 0.0f;
    out1.s0 = -2.f * WINO_V1(0) / 9.f;
    out1.s1 = 4.f * WINO_H1(WINO_V1) / 81.f;
    out1.s2 = 4.f * WINO_H2(WINO_V1) / 81.f;
    out1.s3 = -WINO_H3(WINO_V1) / 405.f;
    out1.s4 = -WINO_H4(WINO_V1) / 405.f;
    out1.s5 = -WINO_H5(WINO_V1) / 810.f;
    out1.s6 = -WINO_H6(WINO_V1) / 810.f;
    out1.s7 = -2.f * WINO_V1(4) / 9.f;

    // Row 2
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out2    = 0.0f;
    out2.s0 = -2.f * WINO_V2(0) / 9.f;
    out2.s1 = 4.f * WINO_H1(WINO_V2) / 81.f;
    out2.s2 = 4.f * WINO_H2(WINO_V2) / 81.f;
    out2.s3 = -WINO_H3(WINO_V2) / 405.f;
    out2.s4 = -WINO_H4(WINO_V2) / 405.f;
    out2.s5 = -WINO_H5(WINO_V2) / 810.f;
    out2.s6 = -WINO_H6(WINO_V2) / 810.f;
    out2.s7 = -2.f * WINO_V2(4) / 9.f;

    // Row 3: scale factors are (1/90) * (scale of the horizontal row)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out3    = 0.0f;
    out3.s0 = WINO_V3(0) / 90.f;
    out3.s1 = -WINO_H1(WINO_V3) / 405.f;
    out3.s2 = -WINO_H2(WINO_V3) / 405.f;
    out3.s3 = WINO_H3(WINO_V3) / 8100.f;
    out3.s4 = WINO_H4(WINO_V3) / 8100.f;
    out3.s5 = WINO_H5(WINO_V3) / 16200.f;
    out3.s6 = WINO_H6(WINO_V3) / 16200.f;
    out3.s7 = WINO_V3(4) / 90.f;

    // Row 4
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out4    = 0.0f;
    out4.s0 = WINO_V4(0) / 90.f;
    out4.s1 = -WINO_H1(WINO_V4) / 405.f;
    out4.s2 = -WINO_H2(WINO_V4) / 405.f;
    out4.s3 = WINO_H3(WINO_V4) / 8100.f;
    out4.s4 = WINO_H4(WINO_V4) / 8100.f;
    out4.s5 = WINO_H5(WINO_V4) / 16200.f;
    out4.s6 = WINO_H6(WINO_V4) / 16200.f;
    out4.s7 = WINO_V4(4) / 90.f;

    // Row 5: scale factors are (1/180) * (scale of the horizontal row)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out5    = 0.0f;
    out5.s0 = WINO_V5(0) / 180.f;
    out5.s1 = -WINO_H1(WINO_V5) / 810.f;
    out5.s2 = -WINO_H2(WINO_V5) / 810.f;
    out5.s3 = WINO_H3(WINO_V5) / 16200.f;
    out5.s4 = WINO_H4(WINO_V5) / 16200.f;
    out5.s5 = WINO_H5(WINO_V5) / 32400.f;
    out5.s6 = WINO_H6(WINO_V5) / 32400.f;
    out5.s7 = WINO_V5(4) / 180.f;

    // Row 6
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out6    = 0.0f;
    out6.s0 = WINO_V6(0) / 180.f;
    out6.s1 = -WINO_H1(WINO_V6) / 810.f;
    out6.s2 = -WINO_H2(WINO_V6) / 810.f;
    out6.s3 = WINO_H3(WINO_V6) / 16200.f;
    out6.s4 = WINO_H4(WINO_V6) / 16200.f;
    out6.s5 = WINO_H5(WINO_V6) / 32400.f;
    out6.s6 = WINO_H6(WINO_V6) / 32400.f;
    out6.s7 = WINO_V6(4) / 180.f;

    // Row 7
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out7    = 0.0f;
    out7.s0 = WINO_V7(0);
    out7.s1 = -2.f * WINO_H1(WINO_V7) / 9.f;
    out7.s2 = -2.f * WINO_H2(WINO_V7) / 9.f;
    out7.s3 = WINO_H3(WINO_V7) / 90.f;
    out7.s4 = WINO_H4(WINO_V7) / 90.f;
    out7.s5 = WINO_H5(WINO_V7) / 180.f;
    out7.s6 = WINO_H6(WINO_V7) / 180.f;
    out7.s7 = WINO_V7(4);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    int x0 = get_global_id(2); // idx filter
    int y0 = get_global_id(0); // idx channel

    // Get output address: filter index along X, channel index along Y
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;

// Write the eight lanes of `vec` to the eight consecutive Z planes starting at plane row * 8
#define WINO_STORE_ROW(row, vec)                                                        \
    do                                                                                  \
    {                                                                                   \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 0) * dst_stride_z) = (vec).s0; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 1) * dst_stride_z) = (vec).s1; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 2) * dst_stride_z) = (vec).s2; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 3) * dst_stride_z) = (vec).s3; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 4) * dst_stride_z) = (vec).s4; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 5) * dst_stride_z) = (vec).s5; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 6) * dst_stride_z) = (vec).s6; \
        *(__global DATA_TYPE *)(dst_addr + ((row) * 8 + 7) * dst_stride_z) = (vec).s7; \
    } while(0)

    // Store the values across the channels: 8 planes for the 1D variants, 64 for the full 2D transform
    WINO_STORE_ROW(0, out0);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    WINO_STORE_ROW(1, out1);
    WINO_STORE_ROW(2, out2);
    WINO_STORE_ROW(3, out3);
    WINO_STORE_ROW(4, out4);
    WINO_STORE_ROW(5, out5);
    WINO_STORE_ROW(6, out6);
    WINO_STORE_ROW(7, out7);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

// Keep the helper macros local to this kernel
#undef WINO_STORE_ROW
#undef WINO_H6
#undef WINO_H5
#undef WINO_H4
#undef WINO_H3
#undef WINO_H2
#undef WINO_H1
#undef WINO_V7
#undef WINO_V6
#undef WINO_V5
#undef WINO_V4
#undef WINO_V3
#undef WINO_V2
#undef WINO_V1
#undef WINO_V0
#undef WINO_LOAD
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001048#endif // defined(SRC_DIM_Z)
1049
1050#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * This kernel only forwards its arguments, unchanged, to @ref winograd_filter_transform_2x2_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Same argument list, in the same order, as the 2x2/3x3 kernel: source tensor (4D) first, destination (3D) second
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x, src_step_x,
                                           src_stride_y, src_step_y,
                                           src_stride_z, src_step_z,
                                           src_stride_w, src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x, dst_step_x,
                                           dst_stride_y, dst_step_y,
                                           dst_stride_z, dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1099
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * This kernel only forwards its arguments, unchanged, to @ref winograd_filter_transform_4x4_3x3_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Same argument list, in the same order, as the 4x4/3x3 kernel: source tensor (4D) first, destination (3D) second
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x, src_step_x,
                                           src_stride_y, src_step_y,
                                           src_stride_z, src_step_z,
                                           src_stride_w, src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x, dst_step_x,
                                           dst_stride_y, dst_step_y,
                                           dst_stride_z, dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1148
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1
 *
 * This kernel only forwards its arguments, unchanged, to @ref winograd_filter_transform_4x4_5x5_nchw.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Same argument list, in the same order, as the 4x4/5x5 kernel: source tensor (4D) first, destination (3D) second
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x, src_step_x,
                                           src_stride_y, src_step_y,
                                           src_stride_z, src_step_z,
                                           src_stride_w, src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x, dst_step_x,
                                           dst_stride_y, dst_step_y,
                                           dst_stride_z, dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001197
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NHWC and the output tile is 4x1
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
1221 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1222 */
1223__kernel void winograd_filter_transform_4x1_3x1_nhwc(
1224 TENSOR4D_DECLARATION(src),
1225 TENSOR3D_DECLARATION(dst))
1226{
1227 winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
1228 src_stride_x,
1229 src_step_x,
1230 src_stride_y,
1231 src_step_y,
1232 src_stride_z,
1233 src_step_z,
1234 src_stride_w,
1235 src_step_w,
1236 src_offset_first_element_in_bytes,
1237 dst_ptr,
1238 dst_stride_x,
1239 dst_step_x,
1240 dst_stride_y,
1241 dst_step_y,
1242 dst_stride_z,
1243 dst_step_z,
1244 dst_offset_first_element_in_bytes);
1245}
1246
1247/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NHWC and the output tile is 4x1
1248 *
1249 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1250 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001251 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001252 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001253 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001254 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1255 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1256 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1257 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1258 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1259 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1260 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1261 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1262 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1263 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1264 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1265 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1266 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1267 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1268 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1269 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1270 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1271 */
1272__kernel void winograd_filter_transform_4x1_5x1_nhwc(
1273 TENSOR4D_DECLARATION(src),
1274 TENSOR3D_DECLARATION(dst))
1275{
1276 winograd_filter_transform_4x4_5x5_nhwc(src_ptr,
1277 src_stride_x,
1278 src_step_x,
1279 src_stride_y,
1280 src_step_y,
1281 src_stride_z,
1282 src_step_z,
1283 src_stride_w,
1284 src_step_w,
1285 src_offset_first_element_in_bytes,
1286 dst_ptr,
1287 dst_stride_x,
1288 dst_step_x,
1289 dst_stride_y,
1290 dst_step_y,
1291 dst_stride_z,
1292 dst_step_z,
1293 dst_offset_first_element_in_bytes);
1294}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001295#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
1296
1297#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1298/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
1299 *
1300 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1301 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001302 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001303 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001304 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001305 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1306 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1307 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1308 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1309 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1310 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1311 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1312 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1313 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1314 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1315 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1316 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1317 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1318 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1319 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1320 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1321 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1322 */
1323__kernel void winograd_filter_transform_1x2_1x3_nchw(
1324 TENSOR4D_DECLARATION(src),
1325 TENSOR3D_DECLARATION(dst))
1326{
1327 winograd_filter_transform_2x2_3x3_nchw(src_ptr,
1328 src_stride_x,
1329 src_step_x,
1330 src_stride_y,
1331 src_step_y,
1332 src_stride_z,
1333 src_step_z,
1334 src_stride_w,
1335 src_step_w,
1336 src_offset_first_element_in_bytes,
1337 dst_ptr,
1338 dst_stride_x,
1339 dst_step_x,
1340 dst_stride_y,
1341 dst_step_y,
1342 dst_stride_z,
1343 dst_step_z,
1344 dst_offset_first_element_in_bytes);
1345}
1346
1347/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
1348 *
1349 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1350 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001351 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001352 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001353 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001354 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1355 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1356 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1357 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1358 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1359 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1360 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1361 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1362 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1363 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1364 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1365 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1366 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1367 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1368 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1369 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1370 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1371 */
1372__kernel void winograd_filter_transform_1x4_1x3_nchw(
1373 TENSOR4D_DECLARATION(src),
1374 TENSOR3D_DECLARATION(dst))
1375{
1376 winograd_filter_transform_4x4_3x3_nchw(src_ptr,
1377 src_stride_x,
1378 src_step_x,
1379 src_stride_y,
1380 src_step_y,
1381 src_stride_z,
1382 src_step_z,
1383 src_stride_w,
1384 src_step_w,
1385 src_offset_first_element_in_bytes,
1386 dst_ptr,
1387 dst_stride_x,
1388 dst_step_x,
1389 dst_stride_y,
1390 dst_step_y,
1391 dst_stride_z,
1392 dst_step_z,
1393 dst_offset_first_element_in_bytes);
1394}
1395
1396/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4
1397 *
1398 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1399 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001400 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001401 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001402 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001403 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1404 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1405 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1406 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1407 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1408 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1409 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1410 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1411 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1412 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1413 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1414 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1415 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1416 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1417 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1418 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1419 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1420 */
1421__kernel void winograd_filter_transform_1x4_1x5_nchw(
1422 TENSOR4D_DECLARATION(src),
1423 TENSOR3D_DECLARATION(dst))
1424{
1425 winograd_filter_transform_4x4_5x5_nchw(src_ptr,
1426 src_stride_x,
1427 src_step_x,
1428 src_stride_y,
1429 src_step_y,
1430 src_stride_z,
1431 src_step_z,
1432 src_stride_w,
1433 src_step_w,
1434 src_offset_first_element_in_bytes,
1435 dst_ptr,
1436 dst_stride_x,
1437 dst_step_x,
1438 dst_stride_y,
1439 dst_step_y,
1440 dst_stride_z,
1441 dst_step_z,
1442 dst_offset_first_element_in_bytes);
1443}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001444
1445/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NHWC and the output tile is 1x4
1446 *
1447 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1448 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001449 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001450 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001451 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001452 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1453 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1454 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1455 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1456 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1457 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1458 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1459 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1460 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1461 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1462 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1463 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1464 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1465 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1466 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1467 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1468 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1469 */
1470__kernel void winograd_filter_transform_1x4_1x3_nhwc(
1471 TENSOR4D_DECLARATION(src),
1472 TENSOR3D_DECLARATION(dst))
1473{
1474 winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
1475 src_stride_x,
1476 src_step_x,
1477 src_stride_y,
1478 src_step_y,
1479 src_stride_z,
1480 src_step_z,
1481 src_stride_w,
1482 src_step_w,
1483 src_offset_first_element_in_bytes,
1484 dst_ptr,
1485 dst_stride_x,
1486 dst_step_x,
1487 dst_stride_y,
1488 dst_step_y,
1489 dst_stride_z,
1490 dst_step_z,
1491 dst_offset_first_element_in_bytes);
1492}
1493
1494/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NHWC and the output tile is 1x4
1495 *
1496 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1497 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001498 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001499 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001500 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001501 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1502 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1503 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1504 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1505 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1506 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1507 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1508 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1509 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1510 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1511 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1512 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1513 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1514 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1515 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1516 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1517 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1518 */
1519__kernel void winograd_filter_transform_1x4_1x5_nhwc(
1520 TENSOR4D_DECLARATION(src),
1521 TENSOR3D_DECLARATION(dst))
1522{
1523 winograd_filter_transform_4x4_5x5_nhwc(src_ptr,
1524 src_stride_x,
1525 src_step_x,
1526 src_stride_y,
1527 src_step_y,
1528 src_stride_z,
1529 src_step_z,
1530 src_stride_w,
1531 src_step_w,
1532 src_offset_first_element_in_bytes,
1533 dst_ptr,
1534 dst_stride_x,
1535 dst_step_x,
1536 dst_stride_y,
1537 dst_step_y,
1538 dst_stride_z,
1539 dst_step_z,
1540 dst_offset_first_element_in_bytes);
1541}
Giorgio Arena149fdf32018-07-04 17:03:33 +01001542#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)