blob: 5c3bb8aa9b9e41295512da96244695b9a6a7731b [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
Michele Di Giorgiod9eaf612020-07-08 11:12:57 +01002 * Copyright (c) 2018-2019 Arm Limited.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
26#if defined(SRC_DIM_Z)
27
/** Apply the 1D Winograd filter-transform coefficients for a 7-tap filter to one row.
 *
 * Reads components s0..s6 of @p tmp and writes the 8 transformed values into components s0..s7 of @p out.
 * Judging by the macro name and the 7 -> 8 mapping, this is the per-row step of the filter transform for a
 * 2x2 output tile and a 7x7 filter (TODO confirm against the kernels that use it).
 *
 * Both arguments are parenthesised in the expansion, so non-trivial expressions (e.g. a dereferenced pointer)
 * can be passed. @p tmp is evaluated several times, so it must not have side effects.
 *
 * @param out Destination vector with at least 8 components (s0..s7); written in place
 * @param tmp Source vector with at least 7 components (s0..s6); only read
 */
#define OUTPUT_ROW_2x2_7x7(out, tmp)                                                                                                                                  \
    ({                                                                                                                                                                \
        (out).s0 = -(tmp).s0 / 36.f;                                                                                                                                  \
        (out).s1 = ((tmp).s0 - (tmp).s1 + (tmp).s2 - (tmp).s3 + (tmp).s4 - (tmp).s5 + (tmp).s6) / 48.f;                                                              \
        (out).s2 = ((tmp).s0 + (tmp).s1 + (tmp).s2 + (tmp).s3 + (tmp).s4 + (tmp).s5 + (tmp).s6) / 48.f;                                                              \
        (out).s3 = (-(tmp).s0 + 2.f * (tmp).s1 - 4.f * (tmp).s2 + 8.f * (tmp).s3 - 16.f * (tmp).s4 + 32.f * (tmp).s5 - 64.f * (tmp).s6) / 120.f;                     \
        (out).s4 = (-(tmp).s0 - 2.f * (tmp).s1 - 4.f * (tmp).s2 - 8.f * (tmp).s3 - 16.f * (tmp).s4 - 32.f * (tmp).s5 - 64.f * (tmp).s6) / 120.f;                     \
        (out).s5 = ((tmp).s0 - 3.f * (tmp).s1 + 9.f * (tmp).s2 - 27.f * (tmp).s3 + 81.f * (tmp).s4 - 243.f * (tmp).s5 + 729.f * (tmp).s6) / 720.f;                   \
        (out).s6 = ((tmp).s0 + 3.f * (tmp).s1 + 9.f * (tmp).s2 + 27.f * (tmp).s3 + 81.f * (tmp).s4 + 243.f * (tmp).s5 + 729.f * (tmp).s6) / 720.f;                   \
        (out).s7 = (tmp).s6;                                                                                                                                          \
    })
39
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010040/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
41 *
42 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
43 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
44 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010045 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010046 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010047 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010048 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
49 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
50 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
51 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
52 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
53 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
54 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
55 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
56 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
57 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
58 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
59 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
60 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
61 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
62 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
63 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
64 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
65 */
66__kernel void winograd_filter_transform_2x2_3x3_nchw(
67 TENSOR4D_DECLARATION(src),
68 TENSOR3D_DECLARATION(dst))
69{
70 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
71
72 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
73
74 // Load the values from the input tensor
75#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010076 VEC_DATA_TYPE(DATA_TYPE, 3)
77 w0 = vload3(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010078#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010079 VEC_DATA_TYPE(DATA_TYPE, 3)
80 w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
81 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
82 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010083#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010084 VEC_DATA_TYPE(DATA_TYPE, 3)
85 w0 = vload3(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
86 VEC_DATA_TYPE(DATA_TYPE, 3)
87 w1 = vload3(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
88 VEC_DATA_TYPE(DATA_TYPE, 3)
89 w2 = vload3(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010090#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
91
92 // Row 0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010093 VEC_DATA_TYPE(DATA_TYPE, 4)
94 out0 = 0.0f;
95 out0.s0 = (w0.s0);
96 out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
97 out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
98 out0.s3 = (w0.s2);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010099
100#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
101 // Row 1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100102 VEC_DATA_TYPE(DATA_TYPE, 4)
103 out1 = 0.0f;
104 out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
105 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
106 out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
107 out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100108
109 // Row 2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100110 VEC_DATA_TYPE(DATA_TYPE, 4)
111 out2 = 0.0f;
112 out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
113 out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
114 out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
115 out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100116
117 // Row 3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100118 VEC_DATA_TYPE(DATA_TYPE, 4)
119 out3 = 0.0f;
120 out3.s0 = (w2.s0);
121 out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
122 out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
123 out3.s3 = (w2.s2);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100124#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
125
126 int z = get_global_id(2);
127 int x0 = z / SRC_DIM_Z; // idx filter
128 int y0 = z % SRC_DIM_Z; // idx channel
129
130 // Get output address
131 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
132
133 // Store the values across the channels
134 // 16 channels for 3x3 kernels
135 // 4 channels for 3x1 or 1x3 kernels
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100136 *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
137 *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
138 *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
139 *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100140
141#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100142 *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out1.s0;
143 *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out1.s1;
144 *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out1.s2;
145 *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out1.s3;
146 *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out2.s0;
147 *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out2.s1;
148 *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out2.s2;
149 *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out2.s3;
150 *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out3.s0;
151 *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out3.s1;
152 *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out3.s2;
153 *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out3.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100154#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
155}
156
157/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
158 *
159 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
160 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
161 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100162 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100163 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100164 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100165 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
166 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
167 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
168 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
169 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
170 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
171 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
172 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
173 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
174 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
175 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
176 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
177 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
178 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
179 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
180 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
181 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
182 */
183__kernel void winograd_filter_transform_4x4_3x3_nchw(
184 TENSOR4D_DECLARATION(src),
185 TENSOR3D_DECLARATION(dst))
186{
187 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
188
189 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
190
191 // Load the values from the input tensor
192#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100193 VEC_DATA_TYPE(DATA_TYPE, 3)
194 w0 = vload3(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100195#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100196 VEC_DATA_TYPE(DATA_TYPE, 3)
197 w0 = (VEC_DATA_TYPE(DATA_TYPE, 3))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
198 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
199 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100200#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100201 VEC_DATA_TYPE(DATA_TYPE, 3)
202 w0 = vload3(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
203 VEC_DATA_TYPE(DATA_TYPE, 3)
204 w1 = vload3(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
205 VEC_DATA_TYPE(DATA_TYPE, 3)
206 w2 = vload3(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100207#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
208
209 // Row 0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100210 VEC_DATA_TYPE(DATA_TYPE, 8)
211 out0 = 0.0f;
212 out0.s0 = (w0.s0) / 16.f;
213 out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
214 out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
215 out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
216 out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
217 out0.s5 = (w0.s2) / 4.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100218
219#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
220 // Row 1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100221 VEC_DATA_TYPE(DATA_TYPE, 8)
222 out1 = 0.0f;
223 out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
224 out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
225 out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
226 out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
227 out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
228 out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100229
230 // Row 2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100231 VEC_DATA_TYPE(DATA_TYPE, 8)
232 out2 = 0.0f;
233 out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
234 out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
235 out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
236 out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
237 out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
238 out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100239
240 // Row 3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100241 VEC_DATA_TYPE(DATA_TYPE, 8)
242 out3 = 0.0f;
243 out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
244 out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
245 out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
246 out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
247 out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
248 out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100249
250 // Row 4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100251 VEC_DATA_TYPE(DATA_TYPE, 8)
252 out4 = 0.0f;
253 out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
254 out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
255 out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
256 out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
257 out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
258 out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100259
260 // Row 5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100261 VEC_DATA_TYPE(DATA_TYPE, 8)
262 out5 = 0.0f;
263 out5.s0 = (w2.s0) / 4.f;
264 out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
265 out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
266 out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
267 out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
268 out5.s5 = (w2.s2);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100269#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
270
271 int z = get_global_id(2);
272 int x0 = z / SRC_DIM_Z; // idx filter
273 int y0 = z % SRC_DIM_Z; // idx channel
274
275 // Get output address
276 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
277
278 // Store the values across the channels
279 // 36 channels for 3x3 kernels
280 // 6 channels for 3x1 or 1x3 kernels
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100281 *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
282 *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
283 *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
284 *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
285 *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
286 *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100287
288#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100289 *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out1.s0;
290 *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out1.s1;
291 *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out1.s2;
292 *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out1.s3;
293 *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s4;
294 *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s5;
295 *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out2.s0;
296 *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out2.s1;
297 *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out2.s2;
298 *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out2.s3;
299 *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s4;
300 *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s5;
301 *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out3.s0;
302 *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out3.s1;
303 *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out3.s2;
304 *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out3.s3;
305 *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out3.s4;
306 *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out3.s5;
307 *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out4.s0;
308 *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out4.s1;
309 *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out4.s2;
310 *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out4.s3;
311 *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out4.s4;
312 *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out4.s5;
313 *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out5.s0;
314 *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out5.s1;
315 *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out5.s2;
316 *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out5.s3;
317 *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out5.s4;
318 *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out5.s5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100319#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
320}
321
Giorgio Arena149fdf32018-07-04 17:03:33 +0100322/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NHWC and the output tile is 4x4/4x1/1x4
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100323 *
324 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
Giorgio Arena149fdf32018-07-04 17:03:33 +0100325 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
326 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100327 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100328 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100329 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100330 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
331 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
332 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
333 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
334 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
335 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
336 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
337 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
338 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
339 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
340 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
341 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
342 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
343 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
344 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
345 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
346 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
347 */
348__kernel void winograd_filter_transform_4x4_3x3_nhwc(
349 TENSOR4D_DECLARATION(src),
350 TENSOR3D_DECLARATION(dst))
351{
352 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
353
354 const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
355
356 // Load the values from the input tensor
Giorgio Arena149fdf32018-07-04 17:03:33 +0100357#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100358 DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
359 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
360 DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
Giorgio Arena149fdf32018-07-04 17:03:33 +0100361#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100362 DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
363 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
364 DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
Giorgio Arena149fdf32018-07-04 17:03:33 +0100365#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100366 DATA_TYPE w10 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
367 DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
368 DATA_TYPE w12 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
369 DATA_TYPE w20 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
370 DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
371 DATA_TYPE w22 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
Giorgio Arena149fdf32018-07-04 17:03:33 +0100372#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
373#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100374
375 // Row 0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100376 DATA_TYPE out00, out01, out02, out03, out04, out05;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100377 out00 = (w00) / 16.f;
378 out01 = (-w00 - w01 - w02) / 24.f;
379 out02 = (-w00 + w01 - w02) / 24.f;
380 out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
381 out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
382 out05 = (w02) / 4.f;
383
Giorgio Arena149fdf32018-07-04 17:03:33 +0100384#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100385 // Row 1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100386 DATA_TYPE out10, out11, out12, out13, out14, out15;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100387 out10 = (-w00 - w10 - w20) / 24.f;
388 out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
389 out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
390 out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
391 out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
392 out15 = (-w02 - w12 - w22) / 6.f;
393
394 // Row 2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100395 DATA_TYPE out20, out21, out22, out23, out24, out25;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100396 out20 = (-w00 + w10 - w20) / 24.f;
397 out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
398 out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
399 out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
400 out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
401 out25 = (-w02 + w12 - w22) / 6.f;
402
403 // Row 3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100404 DATA_TYPE out30, out31, out32, out33, out34, out35;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100405 out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
406 out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
407 out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
408 out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
409 out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
410 out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;
411
412 // Row 4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100413 DATA_TYPE out40, out41, out42, out43, out44, out45;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100414 out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
415 out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
416 out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
417 out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
418 out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
419 out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;
420
421 // Row 5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100422 DATA_TYPE out50, out51, out52, out53, out54, out55;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100423 out50 = (w20) / 4.f;
424 out51 = (-w20 - w21 - w22) / 6.f;
425 out52 = (-w20 + w21 - w22) / 6.f;
426 out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
427 out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
428 out55 = (w22);
Giorgio Arena149fdf32018-07-04 17:03:33 +0100429#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100430
431 int x0 = get_global_id(2); // idx filter
432 int y0 = get_global_id(0); // idx channel
433
434 // Get output address
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100435 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100436
437 // Store the values across the channels
Giorgio Arena149fdf32018-07-04 17:03:33 +0100438 // 36 channels for 3x3 kernels
439 // 6 channels for 3x1 or 1x3 kernels
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100440 *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out00;
441 *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out01;
442 *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out02;
443 *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out03;
444 *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out04;
445 *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out05;
Giorgio Arena149fdf32018-07-04 17:03:33 +0100446#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100447 *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out10;
448 *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out11;
449 *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out12;
450 *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out13;
451 *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out14;
452 *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out15;
453 *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out20;
454 *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out21;
455 *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out22;
456 *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out23;
457 *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out24;
458 *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out25;
459 *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out30;
460 *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out31;
461 *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out32;
462 *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out33;
463 *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out34;
464 *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out35;
465 *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out40;
466 *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out41;
467 *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out42;
468 *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out43;
469 *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out44;
470 *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out45;
471 *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out50;
472 *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out51;
473 *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out52;
474 *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out53;
475 *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out54;
476 *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out55;
Giorgio Arena149fdf32018-07-04 17:03:33 +0100477#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100478}
Giorgio Arena149fdf32018-07-04 17:03:33 +0100479
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100480/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NCHW and the output tile is 4x4/4x1 or 1x4
481 *
482 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
483 *
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100484 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
485 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100486 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note Each work-item reads one 5x5 (or 5x1/1x5) filter tile W. It writes the 64 values of G * W * G^T
 *       (or the 8 values of G * w for 5x1/1x5) to consecutive Z planes of dst. G is the 8x5 matrix
 *       listed next to the transform code inside the kernel body.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100487 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100488 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100489 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
490 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
491 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
492 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
493 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
494 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
495 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
496 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
497 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
498 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
499 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
500 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
501 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
502 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
503 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
504 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
505 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
506 */
507__kernel void winograd_filter_transform_4x4_5x5_nchw(
508 TENSOR4D_DECLARATION(src),
509 TENSOR3D_DECLARATION(dst))
510{
511 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
512
513 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
514
515 // Load the filter values from the input tensor. Each group of 5 consecutive values is kept as a 4-vector (values 0..3) plus a scalar (value 4)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100516#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 5x1 filter: the single row of 5 values, read along X
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100517 VEC_DATA_TYPE(DATA_TYPE, 4)
518 w00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
519 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100520#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x5 filter: the single column of 5 values (one per row), packed exactly like the 5x1 case
    // so that the Row 0 formulas below apply unchanged
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100521 VEC_DATA_TYPE(DATA_TYPE, 4)
522 w00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
523 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
524 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
525 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
526 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100527#else // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 5x5 filter: row i is held in wi0 (elements 0..3) and wi1 (element 4)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100528 VEC_DATA_TYPE(DATA_TYPE, 4)
529 w00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
530 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y) + 4);
531 VEC_DATA_TYPE(DATA_TYPE, 4)
532 w10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
533 DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y) + 4);
534 VEC_DATA_TYPE(DATA_TYPE, 4)
535 w20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
536 DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y) + 4);
537 VEC_DATA_TYPE(DATA_TYPE, 4)
538 w30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
539 DATA_TYPE w31 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y) + 4);
540 VEC_DATA_TYPE(DATA_TYPE, 4)
541 w40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
542 DATA_TYPE w41 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y) + 4);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100543#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100544
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100545 // Transform the filter tile: element (r, c) of G * W * G^T is written to out<r>.s<c>.
    // The coefficients below correspond to this 8x5 matrix G:
    //   [    1,     0,    0,     0,     0 ]
    //   [ -2/9,  -2/9, -2/9,  -2/9,  -2/9 ]
    //   [ -2/9,   2/9, -2/9,   2/9,  -2/9 ]
    //   [ 1/90,  1/45, 2/45,  4/45,  8/45 ]
    //   [ 1/90, -1/45, 2/45, -4/45,  8/45 ]
    //   [ 4/45,  2/45, 1/45,  1/90, 1/180 ]
    //   [ 4/45, -2/45, 1/45, -1/90, 1/180 ]
    //   [    0,     0,    0,     0,     1 ]
    // Every lane of each outN vector is assigned explicitly, so the 0.0f initialisations are not relied upon.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100546
547 // Row 0 (for 5x1/1x5 filters this row is the whole 1D transform G * w)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100548 VEC_DATA_TYPE(DATA_TYPE, 8)
549 out0 = 0.0f;
550 out0.s0 = w00.s0;
551 out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
552 out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
553 out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
554 out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
555 out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
556 out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
557 out0.s7 = w01;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100558
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100559#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100560 // Row 1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100561 VEC_DATA_TYPE(DATA_TYPE, 8)
562 out1 = 0.0f;
563 out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
564 out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
565 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
566 out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
567 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
568 out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
569 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
570 out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
571 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
572 out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
573 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
574 out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
575 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
576 out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100577
578 // Row 2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100579 VEC_DATA_TYPE(DATA_TYPE, 8)
580 out2 = 0.0f;
581 out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
582 out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
583 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
584 out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
585 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
586 out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
587 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
588 out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
589 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
590 out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
591 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
592 out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
593 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
594 out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100595
596 // Row 3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100597 VEC_DATA_TYPE(DATA_TYPE, 8)
598 out3 = 0.0f;
599 out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
600 out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
601 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
602 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
603 out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
604 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
605 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
606 out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
607 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
608 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
609 out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
610 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
611 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
612 out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
613 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
614 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
615 out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
616 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
617 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
618 out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100619
620 // Row 4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100621 VEC_DATA_TYPE(DATA_TYPE, 8)
622 out4 = 0.0f;
623 out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
624 out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
625 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
626 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
627 out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
628 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
629 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
630 out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
631 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
632 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
633 out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
634 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
635 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
636 out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
637 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
638 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
639 out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
640 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
641 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
642 out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100643
644 // Row 5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100645 VEC_DATA_TYPE(DATA_TYPE, 8)
646 out5 = 0.0f;
647 out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
648 out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
649 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
650 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
651 out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
652 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
653 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
654 out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
655 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
656 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
657 out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
658 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
659 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
660 out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
661 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
662 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
663 out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f *
664 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
665 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
666 out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100667
668 // Row 6
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100669 VEC_DATA_TYPE(DATA_TYPE, 8)
670 out6 = 0.0f;
671 out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
672 out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
673 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
674 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
675 out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
676 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
677 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
678 out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
679 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
680 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
681 out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
682 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
683 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
684 out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
685 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
686 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
687 out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f *
688 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
689 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
690 out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100691
692 // Row 7
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100693 VEC_DATA_TYPE(DATA_TYPE, 8)
694 out7 = 0.0f;
695 out7.s0 = w40.s0;
696 out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
697 out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
698 out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
699 out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
700 out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
701 out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
702 out7.s7 = w41;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100703#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100704
705 int z = get_global_id(2);
706 int x0 = z / SRC_DIM_Z; // idx filter
707 int y0 = z % SRC_DIM_Z; // idx channel
708
709 // Get output address: X of dst selects the filter, Y of dst selects the channel
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100710 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100711
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100712 // Store the transformed values across the Z planes of dst: element (r, c) goes to plane 8 * r + c (only row 0, planes 0..7, for 5x1/1x5)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100713 *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
714 *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
715 *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
716 *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
717 *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
718 *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
719 *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out0.s6;
720 *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out0.s7;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100721
722#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100723 *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out1.s0;
724 *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out1.s1;
725 *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s2;
726 *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s3;
727 *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out1.s4;
728 *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out1.s5;
729 *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out1.s6;
730 *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out1.s7;
731 *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s0;
732 *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s1;
733 *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out2.s2;
734 *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out2.s3;
735 *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out2.s4;
736 *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out2.s5;
737 *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out2.s6;
738 *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out2.s7;
739 *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out3.s0;
740 *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out3.s1;
741 *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out3.s2;
742 *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out3.s3;
743 *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out3.s4;
744 *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out3.s5;
745 *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out3.s6;
746 *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out3.s7;
747 *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out4.s0;
748 *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out4.s1;
749 *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out4.s2;
750 *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out4.s3;
751 *(__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z) = out4.s4;
752 *(__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z) = out4.s5;
753 *(__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z) = out4.s6;
754 *(__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z) = out4.s7;
755 *(__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z) = out5.s0;
756 *(__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z) = out5.s1;
757 *(__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z) = out5.s2;
758 *(__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z) = out5.s3;
759 *(__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z) = out5.s4;
760 *(__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z) = out5.s5;
761 *(__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z) = out5.s6;
762 *(__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z) = out5.s7;
763 *(__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z) = out6.s0;
764 *(__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z) = out6.s1;
765 *(__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z) = out6.s2;
766 *(__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z) = out6.s3;
767 *(__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z) = out6.s4;
768 *(__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z) = out6.s5;
769 *(__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z) = out6.s6;
770 *(__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z) = out6.s7;
771 *(__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z) = out7.s0;
772 *(__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z) = out7.s1;
773 *(__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z) = out7.s2;
774 *(__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z) = out7.s3;
775 *(__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z) = out7.s4;
776 *(__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z) = out7.s5;
777 *(__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z) = out7.s6;
778 *(__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z) = out7.s7;
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100779#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100780}
781
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100782/** This OpenCL kernel performs Winograd filter transform 5x5/5x1 or 1x5 when the data layout is NHWC and the output tile is 4x4/4x1 or 1x4
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100783 *
784 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
Gian Marco Iodiced28b7512018-07-06 12:59:28 +0100785 * @note If this kernel is used to perform Winograd filter transform 5x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
786 * @note If this kernel is used to perform Winograd filter transform 1x5, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100787 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100788 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100789 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100790 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
791 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
792 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
793 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
794 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
795 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
796 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
797 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
798 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
799 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
800 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
801 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
802 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
803 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
804 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
805 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
806 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
807 */
__kernel void winograd_filter_transform_4x4_5x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Builds the 8x8 transformed filter G * W * G^T of the Winograd F(4x4, 5x5) algorithm for one
    // (channel, filter) pair. The rows of the 8x5 matrix G, as encoded by the arithmetic below, are:
    //   g0 = [1, 0, 0, 0, 0]
    //   g1 = -2/9  * [ 1,  1, 1,  1,  1]     g2 = -2/9  * [ 1, -1, 1, -1,  1]
    //   g3 =  1/90 * [ 1,  2, 4,  8, 16]     g4 =  1/90 * [ 1, -2, 4, -8, 16]
    //   g5 = 1/180 * [16,  8, 4,  2,  1]     g6 = 1/180 * [16, -8, 4, -2,  1]
    //   g7 = [0, 0, 0, 0, 1]
    // When WINOGRAD_FILTER_TRANSFORM_HORIZONTAL or WINOGRAD_FILTER_TRANSFORM_VERTICAL is defined, only the
    // 1D transform G * w of a single 5-tap filter row is produced (8 outputs instead of 64).
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(DATA_TYPE) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // First filter row (w00..w04): the taps step along Z for a 1x5 filter and along Y otherwise
#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
    DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
    DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
    DATA_TYPE w03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
    DATA_TYPE w04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
#else  // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    DATA_TYPE w03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
    DATA_TYPE w04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Filter rows 1..4 start 1..4 Z-planes after row 0; within a row the taps step along Y
    const __global uchar *src_row1 = src_addr + 1 * src_stride_z;
    const __global uchar *src_row2 = src_addr + 2 * src_stride_z;
    const __global uchar *src_row3 = src_addr + 3 * src_stride_z;
    const __global uchar *src_row4 = src_addr + 4 * src_stride_z;

    DATA_TYPE w10 = *((__global DATA_TYPE *)(src_row1 + 0 * src_stride_y));
    DATA_TYPE w11 = *((__global DATA_TYPE *)(src_row1 + 1 * src_stride_y));
    DATA_TYPE w12 = *((__global DATA_TYPE *)(src_row1 + 2 * src_stride_y));
    DATA_TYPE w13 = *((__global DATA_TYPE *)(src_row1 + 3 * src_stride_y));
    DATA_TYPE w14 = *((__global DATA_TYPE *)(src_row1 + 4 * src_stride_y));

    DATA_TYPE w20 = *((__global DATA_TYPE *)(src_row2 + 0 * src_stride_y));
    DATA_TYPE w21 = *((__global DATA_TYPE *)(src_row2 + 1 * src_stride_y));
    DATA_TYPE w22 = *((__global DATA_TYPE *)(src_row2 + 2 * src_stride_y));
    DATA_TYPE w23 = *((__global DATA_TYPE *)(src_row2 + 3 * src_stride_y));
    DATA_TYPE w24 = *((__global DATA_TYPE *)(src_row2 + 4 * src_stride_y));

    DATA_TYPE w30 = *((__global DATA_TYPE *)(src_row3 + 0 * src_stride_y));
    DATA_TYPE w31 = *((__global DATA_TYPE *)(src_row3 + 1 * src_stride_y));
    DATA_TYPE w32 = *((__global DATA_TYPE *)(src_row3 + 2 * src_stride_y));
    DATA_TYPE w33 = *((__global DATA_TYPE *)(src_row3 + 3 * src_stride_y));
    DATA_TYPE w34 = *((__global DATA_TYPE *)(src_row3 + 4 * src_stride_y));

    DATA_TYPE w40 = *((__global DATA_TYPE *)(src_row4 + 0 * src_stride_y));
    DATA_TYPE w41 = *((__global DATA_TYPE *)(src_row4 + 1 * src_stride_y));
    DATA_TYPE w42 = *((__global DATA_TYPE *)(src_row4 + 2 * src_stride_y));
    DATA_TYPE w43 = *((__global DATA_TYPE *)(src_row4 + 3 * src_stride_y));
    DATA_TYPE w44 = *((__global DATA_TYPE *)(src_row4 + 4 * src_stride_y));
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Output row 0: the scaled rows of G applied to the first filter row
    VEC_DATA_TYPE(DATA_TYPE, 8) out0 = 0.0f;
    out0.s0 = w00;
    out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
    out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
    out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
    out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
    out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
    out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
    out0.s7 = w04;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Column reductions: cK_j is the unscaled coefficient pattern of g_K applied down column j (w0j..w4j).
    // The +/-1 patterns only combine DATA_TYPE values, so they stay in DATA_TYPE. The weighted patterns mix
    // in float literals and are therefore evaluated in float (for the float/half data types this kernel
    // supports), so they are held in float to avoid an extra rounding step.
    const DATA_TYPE c1_0 = w00 + w10 + w20 + w30 + w40;
    const DATA_TYPE c1_1 = w01 + w11 + w21 + w31 + w41;
    const DATA_TYPE c1_2 = w02 + w12 + w22 + w32 + w42;
    const DATA_TYPE c1_3 = w03 + w13 + w23 + w33 + w43;
    const DATA_TYPE c1_4 = w04 + w14 + w24 + w34 + w44;

    const DATA_TYPE c2_0 = w00 - w10 + w20 - w30 + w40;
    const DATA_TYPE c2_1 = w01 - w11 + w21 - w31 + w41;
    const DATA_TYPE c2_2 = w02 - w12 + w22 - w32 + w42;
    const DATA_TYPE c2_3 = w03 - w13 + w23 - w33 + w43;
    const DATA_TYPE c2_4 = w04 - w14 + w24 - w34 + w44;

    const float c3_0 = w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40;
    const float c3_1 = w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41;
    const float c3_2 = w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42;
    const float c3_3 = w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43;
    const float c3_4 = w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44;

    const float c4_0 = w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40;
    const float c4_1 = w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41;
    const float c4_2 = w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42;
    const float c4_3 = w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43;
    const float c4_4 = w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44;

    const float c5_0 = 16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40;
    const float c5_1 = 16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41;
    const float c5_2 = 16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42;
    const float c5_3 = 16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43;
    const float c5_4 = 16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44;

    const float c6_0 = 16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40;
    const float c6_1 = 16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41;
    const float c6_2 = 16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42;
    const float c6_3 = 16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43;
    const float c6_4 = 16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44;

    // Output row 1: g1 down the columns (scale -2/9), then each scaled row of G across
    VEC_DATA_TYPE(DATA_TYPE, 8) out1 = 0.0f;
    out1.s0 = -2.f * c1_0 / 9.f;
    out1.s1 = 4.f * (c1_0 + c1_1 + c1_2 + c1_3 + c1_4) / 81.f;
    out1.s2 = 4.f * (c1_0 - c1_1 + c1_2 - c1_3 + c1_4) / 81.f;
    out1.s3 = -(c1_0 + 2.f * c1_1 + 4.f * c1_2 + 8.f * c1_3 + 16.f * c1_4) / 405.f;
    out1.s4 = -(c1_0 - 2.f * c1_1 + 4.f * c1_2 - 8.f * c1_3 + 16.f * c1_4) / 405.f;
    out1.s5 = -(16.f * c1_0 + 8.f * c1_1 + 4.f * c1_2 + 2.f * c1_3 + c1_4) / 810.f;
    out1.s6 = -(16.f * c1_0 - 8.f * c1_1 + 4.f * c1_2 - 2.f * c1_3 + c1_4) / 810.f;
    out1.s7 = -2.f * c1_4 / 9.f;

    // Output row 2: g2 down the columns (scale -2/9)
    VEC_DATA_TYPE(DATA_TYPE, 8) out2 = 0.0f;
    out2.s0 = -2.f * c2_0 / 9.f;
    out2.s1 = 4.f * (c2_0 + c2_1 + c2_2 + c2_3 + c2_4) / 81.f;
    out2.s2 = 4.f * (c2_0 - c2_1 + c2_2 - c2_3 + c2_4) / 81.f;
    out2.s3 = -(c2_0 + 2.f * c2_1 + 4.f * c2_2 + 8.f * c2_3 + 16.f * c2_4) / 405.f;
    out2.s4 = -(c2_0 - 2.f * c2_1 + 4.f * c2_2 - 8.f * c2_3 + 16.f * c2_4) / 405.f;
    out2.s5 = -(16.f * c2_0 + 8.f * c2_1 + 4.f * c2_2 + 2.f * c2_3 + c2_4) / 810.f;
    out2.s6 = -(16.f * c2_0 - 8.f * c2_1 + 4.f * c2_2 - 2.f * c2_3 + c2_4) / 810.f;
    out2.s7 = -2.f * c2_4 / 9.f;

    // Output row 3: g3 down the columns (scale 1/90)
    VEC_DATA_TYPE(DATA_TYPE, 8) out3 = 0.0f;
    out3.s0 = c3_0 / 90.f;
    out3.s1 = -(c3_0 + c3_1 + c3_2 + c3_3 + c3_4) / 405.f;
    out3.s2 = -(c3_0 - c3_1 + c3_2 - c3_3 + c3_4) / 405.f;
    out3.s3 = (c3_0 + 2.f * c3_1 + 4.f * c3_2 + 8.f * c3_3 + 16.f * c3_4) / 8100.f;
    out3.s4 = (c3_0 - 2.f * c3_1 + 4.f * c3_2 - 8.f * c3_3 + 16.f * c3_4) / 8100.f;
    out3.s5 = (16.f * c3_0 + 8.f * c3_1 + 4.f * c3_2 + 2.f * c3_3 + c3_4) / 16200.f;
    out3.s6 = (16.f * c3_0 - 8.f * c3_1 + 4.f * c3_2 - 2.f * c3_3 + c3_4) / 16200.f;
    out3.s7 = c3_4 / 90.f;

    // Output row 4: g4 down the columns (scale 1/90)
    VEC_DATA_TYPE(DATA_TYPE, 8) out4 = 0.0f;
    out4.s0 = c4_0 / 90.f;
    out4.s1 = -(c4_0 + c4_1 + c4_2 + c4_3 + c4_4) / 405.f;
    out4.s2 = -(c4_0 - c4_1 + c4_2 - c4_3 + c4_4) / 405.f;
    out4.s3 = (c4_0 + 2.f * c4_1 + 4.f * c4_2 + 8.f * c4_3 + 16.f * c4_4) / 8100.f;
    out4.s4 = (c4_0 - 2.f * c4_1 + 4.f * c4_2 - 8.f * c4_3 + 16.f * c4_4) / 8100.f;
    out4.s5 = (16.f * c4_0 + 8.f * c4_1 + 4.f * c4_2 + 2.f * c4_3 + c4_4) / 16200.f;
    out4.s6 = (16.f * c4_0 - 8.f * c4_1 + 4.f * c4_2 - 2.f * c4_3 + c4_4) / 16200.f;
    out4.s7 = c4_4 / 90.f;

    // Output row 5: g5 down the columns (scale 1/180)
    VEC_DATA_TYPE(DATA_TYPE, 8) out5 = 0.0f;
    out5.s0 = c5_0 / 180.f;
    out5.s1 = -(c5_0 + c5_1 + c5_2 + c5_3 + c5_4) / 810.f;
    out5.s2 = -(c5_0 - c5_1 + c5_2 - c5_3 + c5_4) / 810.f;
    out5.s3 = (c5_0 + 2.f * c5_1 + 4.f * c5_2 + 8.f * c5_3 + 16.f * c5_4) / 16200.f;
    out5.s4 = (c5_0 - 2.f * c5_1 + 4.f * c5_2 - 8.f * c5_3 + 16.f * c5_4) / 16200.f;
    out5.s5 = (16.f * c5_0 + 8.f * c5_1 + 4.f * c5_2 + 2.f * c5_3 + c5_4) / 32400.f;
    out5.s6 = (16.f * c5_0 - 8.f * c5_1 + 4.f * c5_2 - 2.f * c5_3 + c5_4) / 32400.f;
    out5.s7 = c5_4 / 180.f;

    // Output row 6: g6 down the columns (scale 1/180)
    VEC_DATA_TYPE(DATA_TYPE, 8) out6 = 0.0f;
    out6.s0 = c6_0 / 180.f;
    out6.s1 = -(c6_0 + c6_1 + c6_2 + c6_3 + c6_4) / 810.f;
    out6.s2 = -(c6_0 - c6_1 + c6_2 - c6_3 + c6_4) / 810.f;
    out6.s3 = (c6_0 + 2.f * c6_1 + 4.f * c6_2 + 8.f * c6_3 + 16.f * c6_4) / 16200.f;
    out6.s4 = (c6_0 - 2.f * c6_1 + 4.f * c6_2 - 8.f * c6_3 + 16.f * c6_4) / 16200.f;
    out6.s5 = (16.f * c6_0 + 8.f * c6_1 + 4.f * c6_2 + 2.f * c6_3 + c6_4) / 32400.f;
    out6.s6 = (16.f * c6_0 - 8.f * c6_1 + 4.f * c6_2 - 2.f * c6_3 + c6_4) / 32400.f;
    out6.s7 = c6_4 / 180.f;

    // Output row 7: g7 selects the last filter row, which is then transformed like row 0
    VEC_DATA_TYPE(DATA_TYPE, 8) out7 = 0.0f;
    out7.s0 = w40;
    out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
    out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
    out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
    out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
    out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
    out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
    out7.s7 = w44;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    const int filter_idx  = get_global_id(2);
    const int channel_idx = get_global_id(0);

    // dst X selects the filter, dst Y the channel; element (r, c) of the tile goes to Z plane 8 * r + c
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + filter_idx * sizeof(DATA_TYPE) + channel_idx * dst_stride_y;

    // Planes 0..7
    __global uchar *dst_tile_row = dst_addr;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out0.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out0.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out0.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out0.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out0.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out0.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out0.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out0.s7;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Planes 8..15
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out1.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out1.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out1.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out1.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out1.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out1.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out1.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out1.s7;

    // Planes 16..23
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out2.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out2.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out2.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out2.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out2.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out2.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out2.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out2.s7;

    // Planes 24..31
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out3.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out3.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out3.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out3.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out3.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out3.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out3.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out3.s7;

    // Planes 32..39
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out4.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out4.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out4.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out4.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out4.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out4.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out4.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out4.s7;

    // Planes 40..47
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out5.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out5.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out5.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out5.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out5.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out5.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out5.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out5.s7;

    // Planes 48..55
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out6.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out6.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out6.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out6.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out6.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out6.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out6.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out6.s7;

    // Planes 56..63
    dst_tile_row += 8 * dst_stride_z;
    *(__global DATA_TYPE *)(dst_tile_row + 0 * dst_stride_z) = out7.s0;
    *(__global DATA_TYPE *)(dst_tile_row + 1 * dst_stride_z) = out7.s1;
    *(__global DATA_TYPE *)(dst_tile_row + 2 * dst_stride_z) = out7.s2;
    *(__global DATA_TYPE *)(dst_tile_row + 3 * dst_stride_z) = out7.s3;
    *(__global DATA_TYPE *)(dst_tile_row + 4 * dst_stride_z) = out7.s4;
    *(__global DATA_TYPE *)(dst_tile_row + 5 * dst_stride_z) = out7.s5;
    *(__global DATA_TYPE *)(dst_tile_row + 6 * dst_stride_z) = out7.s6;
    *(__global DATA_TYPE *)(dst_tile_row + 7 * dst_stride_z) = out7.s7;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}

/** This OpenCL kernel performs Winograd filter transform 7x7/7x1 or 1x7 when the data layout is NHWC and the output tile is 2x2/2x1 or 1x2
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 7x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x7, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per work-item (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
1086__kernel void winograd_filter_transform_2x2_7x7_nhwc(
1087 TENSOR4D_DECLARATION(src),
1088 TENSOR3D_DECLARATION(dst))
1089{
1090 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
1091
1092 const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(DATA_TYPE) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
1093
1094#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1095 // Load the values from the input tensor
1096 DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_z));
1097 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z));
1098 DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z));
1099 DATA_TYPE w03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z));
1100 DATA_TYPE w04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z));
1101 DATA_TYPE w05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z));
1102 DATA_TYPE w06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z));
1103#else // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1104 // Load the values from the input tensor
1105 DATA_TYPE w00 = *((__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
1106 DATA_TYPE w01 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
1107 DATA_TYPE w02 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
1108 DATA_TYPE w03 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
1109 DATA_TYPE w04 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
1110 DATA_TYPE w05 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
1111 DATA_TYPE w06 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
1112#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1113
1114#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1115 DATA_TYPE w10 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
1116 DATA_TYPE w11 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
1117 DATA_TYPE w12 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
1118 DATA_TYPE w13 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
1119 DATA_TYPE w14 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
1120 DATA_TYPE w15 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 5 * src_stride_y));
1121 DATA_TYPE w16 = *((__global DATA_TYPE *)(src_addr + 1 * src_stride_z + 6 * src_stride_y));
1122
1123 DATA_TYPE w20 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
1124 DATA_TYPE w21 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
1125 DATA_TYPE w22 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
1126 DATA_TYPE w23 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
1127 DATA_TYPE w24 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
1128 DATA_TYPE w25 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 5 * src_stride_y));
1129 DATA_TYPE w26 = *((__global DATA_TYPE *)(src_addr + 2 * src_stride_z + 6 * src_stride_y));
1130
1131 DATA_TYPE w30 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
1132 DATA_TYPE w31 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
1133 DATA_TYPE w32 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
1134 DATA_TYPE w33 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
1135 DATA_TYPE w34 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
1136 DATA_TYPE w35 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 5 * src_stride_y));
1137 DATA_TYPE w36 = *((__global DATA_TYPE *)(src_addr + 3 * src_stride_z + 6 * src_stride_y));
1138
1139 DATA_TYPE w40 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
1140 DATA_TYPE w41 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
1141 DATA_TYPE w42 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
1142 DATA_TYPE w43 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
1143 DATA_TYPE w44 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
1144 DATA_TYPE w45 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 5 * src_stride_y));
1145 DATA_TYPE w46 = *((__global DATA_TYPE *)(src_addr + 4 * src_stride_z + 6 * src_stride_y));
1146
1147 DATA_TYPE w50 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 0 * src_stride_y));
1148 DATA_TYPE w51 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 1 * src_stride_y));
1149 DATA_TYPE w52 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 2 * src_stride_y));
1150 DATA_TYPE w53 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 3 * src_stride_y));
1151 DATA_TYPE w54 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 4 * src_stride_y));
1152 DATA_TYPE w55 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 5 * src_stride_y));
1153 DATA_TYPE w56 = *((__global DATA_TYPE *)(src_addr + 5 * src_stride_z + 6 * src_stride_y));
1154
1155 DATA_TYPE w60 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 0 * src_stride_y));
1156 DATA_TYPE w61 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 1 * src_stride_y));
1157 DATA_TYPE w62 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 2 * src_stride_y));
1158 DATA_TYPE w63 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 3 * src_stride_y));
1159 DATA_TYPE w64 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 4 * src_stride_y));
1160 DATA_TYPE w65 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 5 * src_stride_y));
1161 DATA_TYPE w66 = *((__global DATA_TYPE *)(src_addr + 6 * src_stride_z + 6 * src_stride_y));
1162
1163#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1164
1165 VEC_DATA_TYPE(DATA_TYPE, 8)
1166 tmp = 0.0f;
1167
1168 // Row 0
1169 VEC_DATA_TYPE(DATA_TYPE, 8)
1170 out0 = 0.0f;
1171
1172 out0.s0 = -w00 / 36.0f;
1173 out0.s1 = (w00 - w01 + w02 - w03 + w04 - w05 + w06) / 48.f;
1174 out0.s2 = (w00 + w01 + w02 + w03 + w04 + w05 + w06) / 48.f;
1175 out0.s3 = (-w00 + 2.f * w01 - 4.f * w02 + 8.f * w03 - 16.f * w04 + 32.f * w05 - 64.f * w06) / 120.f;
1176 out0.s4 = (-w00 - 2.f * w01 - 4.f * w02 - 8.f * w03 - 16.f * w04 - 32.f * w05 - 64.f * w06) / 120.f;
1177 out0.s5 = (w00 - 3.f * w01 + 9.f * w02 - 27.f * w03 + 81.f * w04 - 243.f * w05 + 729.f * w06) / 720.f;
1178 out0.s6 = (w00 + 3.f * w01 + 9.f * w02 + 27.f * w03 + 81.f * w04 + 243.f * w05 + 729.f * w06) / 720.f;
1179 out0.s7 = w06;
1180
1181 out0 /= (VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.f;
1182
1183#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1184
1185 // Row 1
1186 VEC_DATA_TYPE(DATA_TYPE, 8)
1187 out1 = 0.0f;
1188
1189 tmp.s0 = (w00 - w10 + w20 - w30 + w40 - w50 + w60) / 48.f;
1190 tmp.s1 = (w01 - w11 + w21 - w31 + w41 - w51 + w61) / 48.f;
1191 tmp.s2 = (w02 - w12 + w22 - w32 + w42 - w52 + w62) / 48.f;
1192 tmp.s3 = (w03 - w13 + w23 - w33 + w43 - w53 + w63) / 48.f;
1193 tmp.s4 = (w04 - w14 + w24 - w34 + w44 - w54 + w64) / 48.f;
1194 tmp.s5 = (w05 - w15 + w25 - w35 + w45 - w55 + w65) / 48.f;
1195 tmp.s6 = (w06 - w16 + w26 - w36 + w46 - w56 + w66) / 48.f;
1196
1197 OUTPUT_ROW_2x2_7x7(out1, tmp);
1198
1199 // Row 2
1200 VEC_DATA_TYPE(DATA_TYPE, 8)
1201 out2 = 0.0f;
1202
1203 tmp.s0 = (w00 + w10 + w20 + w30 + w40 + w50 + w60) / 48.f;
1204 tmp.s1 = (w01 + w11 + w21 + w31 + w41 + w51 + w61) / 48.f;
1205 tmp.s2 = (w02 + w12 + w22 + w32 + w42 + w52 + w62) / 48.f;
1206 tmp.s3 = (w03 + w13 + w23 + w33 + w43 + w53 + w63) / 48.f;
1207 tmp.s4 = (w04 + w14 + w24 + w34 + w44 + w54 + w64) / 48.f;
1208 tmp.s5 = (w05 + w15 + w25 + w35 + w45 + w55 + w65) / 48.f;
1209 tmp.s6 = (w06 + w16 + w26 + w36 + w46 + w56 + w66) / 48.f;
1210
1211 OUTPUT_ROW_2x2_7x7(out2, tmp);
1212
1213 // Row 3
1214 VEC_DATA_TYPE(DATA_TYPE, 8)
1215 out3 = 0.0f;
1216
1217 tmp.s0 = (-w00 + 2.f * w10 - 4.f * w20 + 8.f * w30 - 16.f * w40 + 32.f * w50 - 64.f * w60) / 120.f;
1218 tmp.s1 = (-w01 + 2.f * w11 - 4.f * w21 + 8.f * w31 - 16.f * w41 + 32.f * w51 - 64.f * w61) / 120.f;
1219 tmp.s2 = (-w02 + 2.f * w12 - 4.f * w22 + 8.f * w32 - 16.f * w42 + 32.f * w52 - 64.f * w62) / 120.f;
1220 tmp.s3 = (-w03 + 2.f * w13 - 4.f * w23 + 8.f * w33 - 16.f * w43 + 32.f * w53 - 64.f * w63) / 120.f;
1221 tmp.s4 = (-w04 + 2.f * w14 - 4.f * w24 + 8.f * w34 - 16.f * w44 + 32.f * w54 - 64.f * w64) / 120.f;
1222 tmp.s5 = (-w05 + 2.f * w15 - 4.f * w25 + 8.f * w35 - 16.f * w45 + 32.f * w55 - 64.f * w65) / 120.f;
1223 tmp.s6 = (-w06 + 2.f * w16 - 4.f * w26 + 8.f * w36 - 16.f * w46 + 32.f * w56 - 64.f * w66) / 120.f;
1224
1225 OUTPUT_ROW_2x2_7x7(out3, tmp);
1226
1227 // Row 4
1228 VEC_DATA_TYPE(DATA_TYPE, 8)
1229 out4 = 0.0f;
1230
1231 tmp.s0 = (-w00 - 2.f * w10 - 4.f * w20 - 8.f * w30 - 16.f * w40 - 32.f * w50 - 64.f * w60) / 120.f;
1232 tmp.s1 = (-w01 - 2.f * w11 - 4.f * w21 - 8.f * w31 - 16.f * w41 - 32.f * w51 - 64.f * w61) / 120.f;
1233 tmp.s2 = (-w02 - 2.f * w12 - 4.f * w22 - 8.f * w32 - 16.f * w42 - 32.f * w52 - 64.f * w62) / 120.f;
1234 tmp.s3 = (-w03 - 2.f * w13 - 4.f * w23 - 8.f * w33 - 16.f * w43 - 32.f * w53 - 64.f * w63) / 120.f;
1235 tmp.s4 = (-w04 - 2.f * w14 - 4.f * w24 - 8.f * w34 - 16.f * w44 - 32.f * w54 - 64.f * w64) / 120.f;
1236 tmp.s5 = (-w05 - 2.f * w15 - 4.f * w25 - 8.f * w35 - 16.f * w45 - 32.f * w55 - 64.f * w65) / 120.f;
1237 tmp.s6 = (-w06 - 2.f * w16 - 4.f * w26 - 8.f * w36 - 16.f * w46 - 32.f * w56 - 64.f * w66) / 120.f;
1238
1239 OUTPUT_ROW_2x2_7x7(out4, tmp);
1240
1241 // Row 5
1242 VEC_DATA_TYPE(DATA_TYPE, 8)
1243 out5 = 0.0f;
1244
1245 tmp.s0 = (w00 - 3.f * w10 + 9.f * w20 - 27.f * w30 + 81.f * w40 - 243.f * w50 + 729.f * w60) / 720.f;
1246 tmp.s1 = (w01 - 3.f * w11 + 9.f * w21 - 27.f * w31 + 81.f * w41 - 243.f * w51 + 729.f * w61) / 720.f;
1247 tmp.s2 = (w02 - 3.f * w12 + 9.f * w22 - 27.f * w32 + 81.f * w42 - 243.f * w52 + 729.f * w62) / 720.f;
1248 tmp.s3 = (w03 - 3.f * w13 + 9.f * w23 - 27.f * w33 + 81.f * w43 - 243.f * w53 + 729.f * w63) / 720.f;
1249 tmp.s4 = (w04 - 3.f * w14 + 9.f * w24 - 27.f * w34 + 81.f * w44 - 243.f * w54 + 729.f * w64) / 720.f;
1250 tmp.s5 = (w05 - 3.f * w15 + 9.f * w25 - 27.f * w35 + 81.f * w45 - 243.f * w55 + 729.f * w65) / 720.f;
1251 tmp.s6 = (w06 - 3.f * w16 + 9.f * w26 - 27.f * w36 + 81.f * w46 - 243.f * w56 + 729.f * w66) / 720.f;
1252
1253 OUTPUT_ROW_2x2_7x7(out5, tmp);
1254
1255 // Row 6
1256 VEC_DATA_TYPE(DATA_TYPE, 8)
1257 out6 = 0.0f;
1258
1259 tmp.s0 = (w00 + 3.f * w10 + 9.f * w20 + 27.f * w30 + 81.f * w40 + 243.f * w50 + 729.f * w60) / 720.f;
1260 tmp.s1 = (w01 + 3.f * w11 + 9.f * w21 + 27.f * w31 + 81.f * w41 + 243.f * w51 + 729.f * w61) / 720.f;
1261 tmp.s2 = (w02 + 3.f * w12 + 9.f * w22 + 27.f * w32 + 81.f * w42 + 243.f * w52 + 729.f * w62) / 720.f;
1262 tmp.s3 = (w03 + 3.f * w13 + 9.f * w23 + 27.f * w33 + 81.f * w43 + 243.f * w53 + 729.f * w63) / 720.f;
1263 tmp.s4 = (w04 + 3.f * w14 + 9.f * w24 + 27.f * w34 + 81.f * w44 + 243.f * w54 + 729.f * w64) / 720.f;
1264 tmp.s5 = (w05 + 3.f * w15 + 9.f * w25 + 27.f * w35 + 81.f * w45 + 243.f * w55 + 729.f * w65) / 720.f;
1265 tmp.s6 = (w06 + 3.f * w16 + 9.f * w26 + 27.f * w36 + 81.f * w46 + 243.f * w56 + 729.f * w66) / 720.f;
1266
1267 OUTPUT_ROW_2x2_7x7(out6, tmp);
1268
1269 // Row 7
1270 VEC_DATA_TYPE(DATA_TYPE, 8)
1271 out7 = 0.0f;
1272
1273 tmp.s0 = w60;
1274 tmp.s1 = w61;
1275 tmp.s2 = w62;
1276 tmp.s3 = w63;
1277 tmp.s4 = w64;
1278 tmp.s5 = w65;
1279 tmp.s6 = w66;
1280
1281 OUTPUT_ROW_2x2_7x7(out7, tmp);
1282
1283#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1284
1285 int x0 = get_global_id(2); // idx filter
1286 int y0 = get_global_id(0); // idx channel
1287
1288 // Get output address
1289 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(DATA_TYPE) + y0 * dst_stride_y;
1290
1291 // Store the values across the channels
1292 *(__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z) = out0.s0;
1293 *(__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z) = out0.s1;
1294 *(__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z) = out0.s2;
1295 *(__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z) = out0.s3;
1296 *(__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z) = out0.s4;
1297 *(__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z) = out0.s5;
1298 *(__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z) = out0.s6;
1299 *(__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z) = out0.s7;
1300
1301#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1302 *(__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z) = out1.s0;
1303 *(__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z) = out1.s1;
1304 *(__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z) = out1.s2;
1305 *(__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z) = out1.s3;
1306 *(__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z) = out1.s4;
1307 *(__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z) = out1.s5;
1308 *(__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z) = out1.s6;
1309 *(__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z) = out1.s7;
1310 *(__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z) = out2.s0;
1311 *(__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z) = out2.s1;
1312 *(__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z) = out2.s2;
1313 *(__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z) = out2.s3;
1314 *(__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z) = out2.s4;
1315 *(__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z) = out2.s5;
1316 *(__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z) = out2.s6;
1317 *(__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z) = out2.s7;
1318 *(__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z) = out3.s0;
1319 *(__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z) = out3.s1;
1320 *(__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z) = out3.s2;
1321 *(__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z) = out3.s3;
1322 *(__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z) = out3.s4;
1323 *(__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z) = out3.s5;
1324 *(__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z) = out3.s6;
1325 *(__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z) = out3.s7;
1326 *(__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z) = out4.s0;
1327 *(__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z) = out4.s1;
1328 *(__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z) = out4.s2;
1329 *(__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z) = out4.s3;
1330 *(__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z) = out4.s4;
1331 *(__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z) = out4.s5;
1332 *(__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z) = out4.s6;
1333 *(__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z) = out4.s7;
1334 *(__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z) = out5.s0;
1335 *(__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z) = out5.s1;
1336 *(__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z) = out5.s2;
1337 *(__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z) = out5.s3;
1338 *(__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z) = out5.s4;
1339 *(__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z) = out5.s5;
1340 *(__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z) = out5.s6;
1341 *(__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z) = out5.s7;
1342 *(__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z) = out6.s0;
1343 *(__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z) = out6.s1;
1344 *(__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z) = out6.s2;
1345 *(__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z) = out6.s3;
1346 *(__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z) = out6.s4;
1347 *(__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z) = out6.s5;
1348 *(__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z) = out6.s6;
1349 *(__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z) = out6.s7;
1350 *(__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z) = out7.s0;
1351 *(__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z) = out7.s1;
1352 *(__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z) = out7.s2;
1353 *(__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z) = out7.s3;
1354 *(__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z) = out7.s4;
1355 *(__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z) = out7.s5;
1356 *(__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z) = out7.s6;
1357 *(__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z) = out7.s7;
1358#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1359}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001360#endif // defined(SRC_DIM_Z)
1361
1362#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * The transform is delegated to winograd_filter_transform_2x2_3x3_nchw, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1411
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * The transform is delegated to winograd_filter_transform_4x4_3x3_nchw, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1460
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NCHW and the output tile is 4x1
 *
 * The transform is delegated to winograd_filter_transform_4x4_5x5_nchw, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_5x5_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001509
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NHWC and the output tile is 4x1
 *
 * The transform is delegated to winograd_filter_transform_4x4_3x3_nhwc, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_3x3_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
1558
/** This OpenCL kernel performs Winograd filter transform 5x1 when the data layout is NHWC and the output tile is 4x1
 *
 * The transform is delegated to winograd_filter_transform_4x4_5x5_nhwc, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_5x1_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_5x5_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Michele Di Giorgio881c6842019-02-27 14:26:51 +00001607
/** This OpenCL kernel performs Winograd filter transform 7x1 when the data layout is NHWC and the output tile is 2x1
 *
 * The transform is delegated to winograd_filter_transform_2x2_7x7_nhwc, forwarding every src/dst argument unchanged.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float.
 *       NOTE(review): this note lists float only, while @p src_ptr below lists F32/F16 — confirm which data types the 7x7 path supports.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_7x1_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_7x7_nhwc(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001656#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
1657
1658#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
1659/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
1660 *
1661 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
1662 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001663 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001664 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001665 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001666 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1667 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1668 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1669 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1670 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1671 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1672 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1673 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
1674 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1675 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1676 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1677 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1678 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1679 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1680 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1681 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1682 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1683 */
__kernel void winograd_filter_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 2x2 / 3x3 NCHW filter transform.
    // NOTE(review): this section is only built when WINOGRAD_FILTER_TRANSFORM_VERTICAL
    // is defined; presumably that option makes the callee process the 1-D 1x3
    // filter - confirm against the callee's implementation.
    winograd_filter_transform_2x2_3x3_nchw(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
1707
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 4x4 / 3x3 NCHW filter transform.
    // NOTE(review): presumably WINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee
    // process the 1-D 1x3 filter - confirm against the callee's implementation.
    winograd_filter_transform_4x4_3x3_nchw(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
1756
/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NCHW and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 4x4 / 5x5 NCHW filter transform.
    // NOTE(review): presumably WINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee
    // process the 1-D 1x5 filter - confirm against the callee's implementation.
    winograd_filter_transform_4x4_5x5_nchw(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001805
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NHWC and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 4x4 / 3x3 NHWC filter transform.
    // NOTE(review): presumably WINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee
    // process the 1-D 1x3 filter - confirm against the callee's implementation.
    winograd_filter_transform_4x4_3x3_nhwc(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
1854
/** This OpenCL kernel performs Winograd filter transform 1x5 when the data layout is NHWC and the output tile is 1x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 4x4 / 5x5 NHWC filter transform.
    // NOTE(review): presumably WINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee
    // process the 1-D 1x5 filter - confirm against the callee's implementation.
    winograd_filter_transform_4x4_5x5_nhwc(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
Michele Di Giorgio881c6842019-02-27 14:26:51 +00001903
/** This OpenCL kernel performs Winograd filter transform 1x7 when the data layout is NHWC and the output tile is 1x2
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float.
 * @note NOTE(review): the data-type note above lists only float, while @p src_ptr below lists F32/F16.
 *       Confirm which data types the 7x7 filter transform actually supports and make the two agree.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x2_1x7_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Thin wrapper: every kernel argument is passed through unchanged to the
    // 2x2 / 7x7 NHWC filter transform.
    // NOTE(review): presumably WINOGRAD_FILTER_TRANSFORM_VERTICAL makes the callee
    // process the 1-D 1x7 filter - confirm against the callee's implementation.
    winograd_filter_transform_2x2_7x7_nhwc(
        /* source tensor (4D) */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        /* destination tensor (3D) */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
Giorgio Arena149fdf32018-07-04 17:03:33 +01001952#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)