blob: 34bf2902e8ddbaea8dba0a15353348f6a5eda1a6 [file] [log] [blame]
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
/** Applies one 8-point Winograd input-transform pass to @p tmp and stores the result in @p out.
 *  Presumably this is the 1D B^T pass used by the 4x4 output tile / 5x5 filter kernels
 *  (inferred from the macro name — confirm against the callers).
 *
 * @param[out] out       Vector receiving the 8 transformed values (out.s0 .. out.s7)
 * @param[in]  tmp       Vector holding the 8 input values (tmp.s0 .. tmp.s7)
 * @param[out] comm_fact Scratch vector (components s0 .. s6 are overwritten) holding the partial
 *                       sums shared by each pair of outputs out.s1/s2, out.s3/s4 and out.s5/s6
 *
 * @note Every argument is evaluated several times: pass plain variables without side effects.
 * @note The expansion is a GNU-style statement expression ({ ... }).
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                 \
    ({                                                                          \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                       \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;     \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;        \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;            \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;     \
                                                                                \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;   \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                             \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                             \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                             \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                             \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                             \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                             \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;   \
    })
45
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010046#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
 *
 * Each work-item processes a single input channel. It reads 4x4 elements (2D case), 4 consecutive elements along X
 * (horizontal case) or 4 consecutive elements along Y (vertical case). It then writes 16 (2D) or 4 (1D) transformed
 * values, one per destination Z plane, at destination row (x + y * NUM_TILES_X) and column z.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note -DSRC_DEPTH may optionally be passed at compile time. When it is defined, the third global dimension covers
 *       channels and batches (channel = get_global_id(2) % SRC_DEPTH, batch = get_global_id(2) / SRC_DEPTH), and the
 *       batch is addressed through src_stride_w / dst_stride_w. Otherwise both W strides are ignored.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    const int z = get_global_id(2) % SRC_DEPTH;
    const int b = get_global_id(2) / SRC_DEPTH;
#else  /* defined(SRC_DEPTH) */
    const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */

    // Compute input address: first element of the input patch read by this work-item
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#if defined(SRC_DEPTH)
    src_addr += b * src_stride_w;
#endif /* defined(SRC_DEPTH) */

    // Step back by the padding so that the patch starts at the padded origin
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // Gather 4 consecutive elements along Y into a single vector
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Combination across rows: [r0 - r2, r1 + r2, r2 - r1, r1 - r3] (only r0 is used in the 1D cases)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp0 = in_row0;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    tmp0 -= in_row2;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Combination within each vector d: [d0 - d2, d1 + d2, d2 - d1, d1 - d3]
    DATA_TYPE out00 = tmp0.s0 - tmp0.s2;
    DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
    DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
    DATA_TYPE out03 = tmp0.s1 - tmp0.s3;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp1 = in_row1 + in_row2;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp2 = in_row2 - in_row1;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp3 = in_row1 - in_row3;

    DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
    DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
    DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
    DATA_TYPE out13 = tmp1.s1 - tmp1.s3;

    DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
    DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
    DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
    DATA_TYPE out23 = tmp2.s1 - tmp2.s3;

    DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
    DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
    DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
    DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Destination: column z, row (tile index), one Z plane per transformed element
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#if defined(SRC_DEPTH)
    dst_addr += b * dst_stride_w;
#endif /* defined(SRC_DEPTH) */

    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00;
    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01;
    *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02;
    *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z))  = out10;
    *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z))  = out11;
    *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z))  = out12;
    *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z))  = out13;
    *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z))  = out20;
    *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z))  = out21;
    *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
    *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
    *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
    *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
    *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
    *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
182
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is a multiple of 2
 *
 * Each work-item processes two consecutive input channels (z and z + 1). For each channel it reads 4x4 elements (2D case),
 * 4 consecutive elements along X (horizontal case) or 4 consecutive elements along Y (vertical case). It then writes
 * 16 (2D) or 4 (1D) pairs of transformed values. Each pair is stored as a 2-element vector at columns z and z + 1 of
 * destination row (x + y * NUM_TILES_X), one Z plane per transformed element.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note -DSRC_DEPTH may optionally be passed at compile time. When it is defined, the third global dimension covers
 *       channel pairs and batches (channel = (get_global_id(2) * 2) % SRC_DEPTH, batch = (get_global_id(2) * 2) / SRC_DEPTH),
 *       and the batch is addressed through src_stride_w / dst_stride_w. Otherwise both W strides are ignored.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    const int z = (get_global_id(2) * 2) % SRC_DEPTH;
    const int b = (get_global_id(2) * 2) / SRC_DEPTH;
#else  /* defined(SRC_DEPTH) */
    const int z = get_global_id(2) * 2;
#endif /* defined(SRC_DEPTH) */

    // Compute input address: first element of the input patch of channel z
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#if defined(SRC_DEPTH)
    src_addr += b * src_stride_w;
#endif /* defined(SRC_DEPTH) */
    // Step back by the padding so that the patch starts at the padded origin
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

    // Channel z -> in_row0..in_row3
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Channel z + 1 -> in_row4..in_row7
    src_addr += src_stride_z;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Per channel, combination across rows: [r0 - r2, r1 + r2, r2 - r1, r1 - r3] (only r0 is used in the 1D cases)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp0 = in_row0;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp4 = in_row4;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    tmp0 -= in_row2;
    tmp4 -= in_row6;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Combination within each vector d: [d0 - d2, d1 + d2, d2 - d1, d1 - d3]; .s0 = channel z, .s1 = channel z + 1
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp1 = in_row1 + in_row2;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp2 = in_row2 - in_row1;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp3 = in_row1 - in_row3;

    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp5 = in_row5 + in_row6;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp6 = in_row6 - in_row5;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp7 = in_row5 - in_row7;

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Destination: columns z and z + 1, row (tile index), one Z plane per transformed element
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#if defined(SRC_DEPTH)
    dst_addr += b * dst_stride_w;
#endif /* defined(SRC_DEPTH) */

    vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z));
    vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
    vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
    vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z));
    vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
    vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
    vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
    vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
    vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
    vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
    vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
    vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
    vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
    vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
    vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
364
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100365/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100366 *
367 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
368 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
369 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
370 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
371 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
372 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100373 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100374 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100375 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100376 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
377 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
378 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
379 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
380 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
381 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
382 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
383 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
384 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
385 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
386 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
387 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
388 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
389 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
390 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100391 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
392 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100393 */
394__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
395 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100396 TENSOR3D_DECLARATION(dst),
397 uint src_stride_w,
398 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100399{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100400 const int x = get_global_id(0);
401 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000402#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100403 const int z = get_global_id(2) % SRC_DEPTH;
404 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000405#else /* defined(SRC_DEPTH) */
406 const int z = get_global_id(2);
407#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100408
409 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000410#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100411 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000412#else /* defined(SRC_DEPTH) */
413 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
414#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100415
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100416 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100417
418#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
419 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100420 VEC_DATA_TYPE(DATA_TYPE, 4)
421 d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
422 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
423 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
424 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
425 VEC_DATA_TYPE(DATA_TYPE, 2)
426 d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
427 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100428#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
429 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100430 VEC_DATA_TYPE(DATA_TYPE, 4)
431 d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
432 VEC_DATA_TYPE(DATA_TYPE, 2)
433 d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100434#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
435
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100436 DATA_TYPE out0 = 0.0f;
437 DATA_TYPE out1 = 0.0f;
438 DATA_TYPE out2 = 0.0f;
439 DATA_TYPE out3 = 0.0f;
440 DATA_TYPE out4 = 0.0f;
441 DATA_TYPE out5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100442
443 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
444 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
445 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
446 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
447 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
448 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
449 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
450
451#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
452 // Row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100453 VEC_DATA_TYPE(DATA_TYPE, 4)
454 d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
455 VEC_DATA_TYPE(DATA_TYPE, 2)
456 d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100457
458 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100459 DATA_TYPE k0 = d41.s0;
460 DATA_TYPE k1 = d41.s0;
461 DATA_TYPE k2 = d41.s0;
462 DATA_TYPE k3 = d41.s0;
463 DATA_TYPE k4 = d41.s0;
464 DATA_TYPE k5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100465
466 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
467 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
468 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
469 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
470 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
471 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
472
473 out0 += k0;
474 out1 += k1;
475 out2 += k2;
476 out3 += k3;
477 out4 += k4;
478 out5 += k5;
479
480 // Row2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100481 VEC_DATA_TYPE(DATA_TYPE, 4)
482 d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
483 VEC_DATA_TYPE(DATA_TYPE, 2)
484 d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100485
486 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
487 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
488 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
489 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
490 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
491 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
492#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
493
494 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000495#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100496 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000497#else /* defined(SRC_DEPTH) */
498 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
499#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100500
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100501 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100502
503 *(dst_addr) = out0;
504 dst_addr += dst_plane_stride;
505 *(dst_addr) = out1;
506 dst_addr += dst_plane_stride;
507 *(dst_addr) = out2;
508 dst_addr += dst_plane_stride;
509 *(dst_addr) = out3;
510 dst_addr += dst_plane_stride;
511 *(dst_addr) = out4;
512 dst_addr += dst_plane_stride;
513 *(dst_addr) = out5;
514 dst_addr += dst_plane_stride;
515
516#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100517 DATA_TYPE out6 = k0;
518 DATA_TYPE out7 = k1;
519 DATA_TYPE out8 = k2;
520 DATA_TYPE out9 = k3;
521 DATA_TYPE out10 = k4;
522 DATA_TYPE out11 = k5;
523 DATA_TYPE out12 = k0;
524 DATA_TYPE out13 = k1;
525 DATA_TYPE out14 = k2;
526 DATA_TYPE out15 = k3;
527 DATA_TYPE out16 = k4;
528 DATA_TYPE out17 = k5;
529 DATA_TYPE out18 = k0;
530 DATA_TYPE out19 = k1;
531 DATA_TYPE out20 = k2;
532 DATA_TYPE out21 = k3;
533 DATA_TYPE out22 = k4;
534 DATA_TYPE out23 = k5;
535 DATA_TYPE out24 = k0;
536 DATA_TYPE out25 = k1;
537 DATA_TYPE out26 = k2;
538 DATA_TYPE out27 = k3;
539 DATA_TYPE out28 = k4;
540 DATA_TYPE out29 = k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100541
542 // Row1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100543 VEC_DATA_TYPE(DATA_TYPE, 4)
544 d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
545 VEC_DATA_TYPE(DATA_TYPE, 2)
546 d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100547
548 // Row3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100549 VEC_DATA_TYPE(DATA_TYPE, 4)
550 d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
551 VEC_DATA_TYPE(DATA_TYPE, 2)
552 d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100553
554 // Compute common parts for the channels between [6, 29]
555 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
556 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100557 DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
558 DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
559 DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
560 DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
561 DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
562 DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
563 DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
564 DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
565 DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
566 DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
567 DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
568 DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100569
570 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
571 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100572 DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
573 DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
574 DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
575 DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
576 DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
577 DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
578 DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
579 DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
580 DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
581 DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
582 DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
583 DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100584
585 out6 += part0 - part1;
586 out12 += part0 + part1;
587 out7 += part2 + part3 + part4 + part5;
588 out8 += part2 - part3 + part4 - part5;
589 out13 += part2 + part3 - part4 - part5;
590 out14 += part2 - part3 - part4 + part5;
591 out9 += part6 + part7 + part8 + part9;
592 out10 += part6 - part7 + part8 - part9;
593 out15 += part6 - part7 - part8 + part9;
594 out16 += part6 + part7 - part8 - part9;
595 out11 += part10 + part11;
596 out17 += part10 - part11;
597
598 out18 += part13 - part12;
599 out24 += part13 + part12;
600 out19 += part14 + part15 + part16 + part17;
601 out20 += part14 - part15 + part16 - part17;
602 out25 += part14 - part15 - part16 + part17;
603 out26 += part14 + part15 - part16 - part17;
604 out21 += part18 + part19 + part20 + part21;
605 out22 += part18 - part19 + part20 - part21;
606 out27 += part18 - part19 - part20 + part21;
607 out28 += part18 + part19 - part20 - part21;
608 out23 += part22 + part23;
609 out29 += part22 - part23;
610
611 *(dst_addr) = out6;
612 dst_addr += dst_plane_stride;
613 *(dst_addr) = out7;
614 dst_addr += dst_plane_stride;
615 *(dst_addr) = out8;
616 dst_addr += dst_plane_stride;
617 *(dst_addr) = out9;
618 dst_addr += dst_plane_stride;
619 *(dst_addr) = out10;
620 dst_addr += dst_plane_stride;
621 *(dst_addr) = out11;
622 dst_addr += dst_plane_stride;
623 *(dst_addr) = out12;
624 dst_addr += dst_plane_stride;
625 *(dst_addr) = out13;
626 dst_addr += dst_plane_stride;
627 *(dst_addr) = out14;
628 dst_addr += dst_plane_stride;
629 *(dst_addr) = out15;
630 dst_addr += dst_plane_stride;
631 *(dst_addr) = out16;
632 dst_addr += dst_plane_stride;
633 *(dst_addr) = out17;
634 dst_addr += dst_plane_stride;
635
636 *(dst_addr) = out18;
637 dst_addr += dst_plane_stride;
638 *(dst_addr) = out19;
639 dst_addr += dst_plane_stride;
640 *(dst_addr) = out20;
641 dst_addr += dst_plane_stride;
642 *(dst_addr) = out21;
643 dst_addr += dst_plane_stride;
644 *(dst_addr) = out22;
645 dst_addr += dst_plane_stride;
646 *(dst_addr) = out23;
647 dst_addr += dst_plane_stride;
648 *(dst_addr) = out24;
649 dst_addr += dst_plane_stride;
650 *(dst_addr) = out25;
651 dst_addr += dst_plane_stride;
652 *(dst_addr) = out26;
653 dst_addr += dst_plane_stride;
654 *(dst_addr) = out27;
655 dst_addr += dst_plane_stride;
656 *(dst_addr) = out28;
657 dst_addr += dst_plane_stride;
658 *(dst_addr) = out29;
659 dst_addr += dst_plane_stride;
660
661 // Row5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100662 VEC_DATA_TYPE(DATA_TYPE, 4)
663 d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
664 VEC_DATA_TYPE(DATA_TYPE, 2)
665 d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100666
667 // Channels [30, 35]
668 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
669 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
670 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
671 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
672 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
673 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
674
675 *(dst_addr) = out0;
676 dst_addr += dst_plane_stride;
677 *(dst_addr) = out1;
678 dst_addr += dst_plane_stride;
679 *(dst_addr) = out2;
680 dst_addr += dst_plane_stride;
681 *(dst_addr) = out3;
682 dst_addr += dst_plane_stride;
683 *(dst_addr) = out4;
684 dst_addr += dst_plane_stride;
685 *(dst_addr) = out5;
686 dst_addr += dst_plane_stride;
687#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
688}
689
/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
 *
 * Each work-item reads one input tile:
 * - an 8x8 tile in the 2D case;
 * - a single 8-element row for 5x1;
 * - a single 8-element column for 1x5.
 * The tile starts PAD_LEFT columns and PAD_TOP rows before the tile origin (x * OUTPUT_TILE_W, y * OUTPUT_TILE_H).
 *
 * The work-item writes the 64 (respectively 8) transformed values to consecutive Z planes of the destination, one value per plane.
 * Within each plane, the tile index (x + y * NUM_TILES_X) selects the row and the input channel selects the element.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note Optionally, the depth of the source tensor can be passed using -DSRC_DEPTH (e.g. -DSRC_DEPTH=64).
 *       In that case the third global ID is split into a channel index (get_global_id(2) % SRC_DEPTH)
 *       and a batch index (get_global_id(2) / SRC_DEPTH).
 *       The batch is addressed through @p src_stride_w and @p dst_stride_w.
 *       Without -DSRC_DEPTH, both W strides are ignored.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes). Only used when -DSRC_DEPTH is defined
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes). Only used when -DSRC_DEPTH is defined
 */
__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    // The third global dimension packs (batch, channel) together
    const int  z                = get_global_id(2) % SRC_DEPTH;
    const int  b                = get_global_id(2) / SRC_DEPTH;
    const uint src_batch_offset = b * src_stride_w;
    const uint dst_batch_offset = b * dst_stride_w;
#else  /* defined(SRC_DEPTH) */
    // Single batch: the W strides do not contribute to the addresses
    const int  z                = get_global_id(2);
    const uint src_batch_offset = 0;
    const uint dst_batch_offset = 0;
#endif /* defined(SRC_DEPTH) */

    // Compute input address: top-left corner of the tile, shifted back by the padding
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + src_batch_offset;
    src_addr                 = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

    // Load input tile
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // Gather one column of 8 elements; the rest of the kernel treats it like a row
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
                                                                              *((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
    const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Row pass: combine the 8 input rows into the 8 intermediate rows tmp0..tmp7.
    // The same linear combinations are applied along the columns by OUTPUT_ROW_4x4_5x5.
    // In the 1D variants only tmp0 (= in_row0) is needed.
    VEC_DATA_TYPE(DATA_TYPE, 8)
    tmp0 = in_row0;
    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact0 = 0.0f;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // tmp0 = r0 - 5.25 * r2 + 5.25 * r4 - r6
    comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25 * in_row4;
    tmp0 += -in_row6 + (DATA_TYPE)5.25 * in_row4 - (DATA_TYPE)5.25 * in_row2;

    // tmp1/tmp2 = (r2 - 4.25 * r4 + r6) +/- (r1 - 4.25 * r3 + r5)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25 * in_row3;
    VEC_DATA_TYPE(DATA_TYPE, 8)
    comm_fact2 = (DATA_TYPE)0.25 * in_row2 - (DATA_TYPE)1.25 * in_row4 + in_row6;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;

    // tmp3/tmp4 = (0.25 * r2 - 1.25 * r4 + r6) +/- (0.5 * r1 - 2.5 * r3 + 2 * r5)
    comm_fact0 = (DATA_TYPE)2.5 * in_row3;
    comm_fact1 = (DATA_TYPE)0.5 * in_row1 - comm_fact0 + (DATA_TYPE)2.0 * in_row5;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;

    // tmp5/tmp6 = (4 * r2 - 5 * r4 + r6) +/- (2 * r1 - 2.5 * r3 + 0.5 * r5)
    comm_fact1 = (DATA_TYPE)2.0 * in_row1 - comm_fact0 + (DATA_TYPE)0.5 * in_row5;
    comm_fact2 = (DATA_TYPE)4.0 * in_row2 - (DATA_TYPE)5.0 * in_row4 + in_row6;

    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;

    // tmp7 = -r1 + 5.25 * r3 - 5.25 * r5 + r7
    const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25 * in_row3 - (DATA_TYPE)5.25 * in_row5;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Column pass: transform each intermediate row (comm_fact0 is reused as scratch by the macro)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out0;

    OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 8)
    out1, out2, out3, out4, out5, out6, out7;

    OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Store values across the channels: element i of the transformed tile goes to Z plane i
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + dst_batch_offset;

    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
    *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
    *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
    *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
    *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
    *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
    *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z))  = out1.s0;
    *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z))  = out1.s1;
    *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
    *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
    *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
    *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
    *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
    *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
    *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
    *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
    *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
    *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
    *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
    *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
    *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
    *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
    *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
    *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
    *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
    *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
    *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
    *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
    *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
    *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
    *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
    *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
    *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
    *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
    *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
    *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
    *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
    *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
    *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
    *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
    *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
    *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
    *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
    *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
    *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
    *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
    *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
    *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
    *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
    *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
    *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
    *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
    *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
    *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
    *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
    *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
    *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
    *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
    *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
    *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
    *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
    *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
890}
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100891
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000892#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100893/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100894 *
895 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
896 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
897 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
898 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100899 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
900 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
901 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
902 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100903 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100904 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100905 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100906 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
907 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
908 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
909 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
910 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
911 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
912 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
913 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
914 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
915 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
916 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
917 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
918 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
919 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
920 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100921 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
922 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100923 */
__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Work-item mapping:
    //   x -> element index along dimension 0 of src/dst (the innermost dimension; channels for NHWC)
    //   y -> tile index along dimension 1 (width for NHWC)
    //   z -> tile index along dimension 2 (height for NHWC). When NUM_TILES_Y is defined,
    //        get_global_id(2) also encodes the batch index b.
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(NUM_TILES_Y)
    const int z = get_global_id(2) % NUM_TILES_Y;
    const int b = get_global_id(2) / NUM_TILES_Y;
#else  /* defined(NUM_TILES_Y) */
    const int z = get_global_id(2);
#endif /* defined(NUM_TILES_Y) */

#if defined(NUM_TILES_Y)
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
#else  /* defined(NUM_TILES_Y) */
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
#endif /* defined(NUM_TILES_Y) */

    // Clamp coordinates. This clamp is valid for all rows.
    // Columns outside [0, SRC_DIM_1 - 1] are mapped to -1 or SRC_DIM_1, i.e. one element past
    // either edge. Rows outside [0, SRC_DIM_2 - 1] reuse the same trick: z is clamped and the
    // column index is forced to -1 / SRC_DIM_1.
    // NOTE(review): this relies on src having at least one element of (presumably zero-filled)
    // border padding along dimension 1 - confirm against the host-side kernel configuration.
    int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
    int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
    y_coord0      = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1);
    y_coord1      = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1);

    int  z_coord;
    int4 valid_y0;
    int2 valid_y1;

    // The coefficients below implement out = B^T * d * B for the 6x6 input tile d, with
    //   B^T = [ 4,  0, -5,  0, 1, 0 ]
    //         [ 0, -4, -4,  1, 1, 0 ]
    //         [ 0,  4, -4, -1, 1, 0 ]
    //         [ 0, -2, -1,  2, 1, 0 ]
    //         [ 0,  2, -1, -2, 1, 0 ]
    //         [ 0,  4,  0, -5, 0, 1 ]
    // Output element (i, j) is written to destination plane 6 * i + j.

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row4
    z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;

    // If z < 0, set y to -1
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    // If z >= SRC_DIM_2, set y to SRC_DIM_1
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);

    // Clamp z coordinate
    z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);

    DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Column transform of row 4 (B^T row 4 weight of every input row-4 term is 1, and
    // rows 0 and 2 of B^T also weight input row 4 by 1, so k* is reused for several outputs)
    DATA_TYPE k0 = d44;
    DATA_TYPE k1 = d44;
    DATA_TYPE k2 = d44;
    DATA_TYPE k3 = d44;
    DATA_TYPE k4 = d44;
    DATA_TYPE k5 = (DATA_TYPE)0.0f;

    k0 += 4.0f * d40 - 5.0f * d42;
    k1 += -4.0f * d41 - 4.0f * d42 + d43;
    k2 += 4.0f * d41 - 4.0f * d42 - d43;
    k3 += -2.0f * d41 + 2.0f * d43 - d42;
    k4 += 2.0f * d41 - 2.0f * d43 - d42;
    k5 += 4.0f * d41 - 5.0f * d43 + d45;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row0
    z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;

#if PAD_TOP != 0
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    // Both select operands must share the same vector type, hence the (int4)/(int2) splats
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
#else  // PAD_TOP != 0
    valid_y0 = y_coord0;
    valid_y1 = y_coord1;
#endif // if PAD_TOP == 0, we cannot read out of bound

    DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Vertical variant: the six tile inputs lie along dimension 2 at a single column
    int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
    int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;

    valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0);
    valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0);
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2);

    z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
    z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1));

    DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z);
    DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z);
    DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z);
    DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z);
    DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z);
    DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Contribution of input row 0 (B^T weights it by 4 in output row 0).
    // In the horizontal/vertical variants nothing else is added, so those outputs keep this
    // factor of 4 relative to the 1D transform.
    // NOTE(review): presumably compensated elsewhere in the pipeline - confirm.
    DATA_TYPE out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
    DATA_TYPE out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04;
    DATA_TYPE out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04;
    DATA_TYPE out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04;
    DATA_TYPE out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04;
    DATA_TYPE out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row2
    z_coord  = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);

    DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Input row 4 contributes with weight 1 to output rows 0..4
    out0 += k0;
    out1 += k1;
    out2 += k2;
    out3 += k3;
    out4 += k4;
    out5 += k5;
    DATA_TYPE out6  = k0;
    DATA_TYPE out7  = k1;
    DATA_TYPE out8  = k2;
    DATA_TYPE out9  = k3;
    DATA_TYPE out10 = k4;
    DATA_TYPE out11 = k5;
    DATA_TYPE out12 = k0;
    DATA_TYPE out13 = k1;
    DATA_TYPE out14 = k2;
    DATA_TYPE out15 = k3;
    DATA_TYPE out16 = k4;
    DATA_TYPE out17 = k5;
    DATA_TYPE out18 = k0;
    DATA_TYPE out19 = k1;
    DATA_TYPE out20 = k2;
    DATA_TYPE out21 = k3;
    DATA_TYPE out22 = k4;
    DATA_TYPE out23 = k5;
    DATA_TYPE out24 = k0;
    DATA_TYPE out25 = k1;
    DATA_TYPE out26 = k2;
    DATA_TYPE out27 = k3;
    DATA_TYPE out28 = k4;
    DATA_TYPE out29 = k5;

    // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
    out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24;
    out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24;
    out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24;
    out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24;
    out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24;
    out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Compute destination address: one tile per dst row, one transformed element per dst plane
#if defined(NUM_TILES_Y)
    __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
#else  /* defined(NUM_TILES_Y) */
    __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
#endif /* defined(NUM_TILES_Y) */

    // dst_addr is a DATA_TYPE pointer, so the plane stride is expressed in elements, not bytes
    uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);

    *dst_addr = out0;
    dst_addr += dst_plane_stride;
    *dst_addr = out1;
    dst_addr += dst_plane_stride;
    *dst_addr = out2;
    dst_addr += dst_plane_stride;
    *dst_addr = out3;
    dst_addr += dst_plane_stride;
    *dst_addr = out4;
    dst_addr += dst_plane_stride;
    *dst_addr = out5;
    dst_addr += dst_plane_stride;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row1
    z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
    // Row1 is not unconditionally in bounds: e.g. PAD_TOP == 2 gives z_coord == -1 for z == 0.
    // Guard it like the other rows; for in-bounds rows this leaves the loaded values unchanged.
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);

    DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Row3
    z_coord  = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);

    DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Compute common parts for the channels between [6, 29]
    // Channels [6, 11]:  [out10, out11, out12, out13, out14, out15]
    // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
    DATA_TYPE part0  = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
    DATA_TYPE part1  = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
    DATA_TYPE part2  = 16.0f * d22 - 4.0f * d24;
    DATA_TYPE part3  = 16.0f * d21 - 4.0f * d23;
    DATA_TYPE part4  = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
    DATA_TYPE part5  = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
    DATA_TYPE part6  = 4.0f * d22 - 4.0f * d24;
    DATA_TYPE part7  = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
    DATA_TYPE part8  = 4.0f * d12 - 4.0f * d14 - d32 + d34;
    DATA_TYPE part9  = 8.0f * d21 - 8.0f * d23;
    DATA_TYPE part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
    DATA_TYPE part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;

    // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
    // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
    DATA_TYPE part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
    DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
    DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d22 - d24
    DATA_TYPE part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
    DATA_TYPE part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
    DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d21 - d23
    DATA_TYPE part18 = part6 * 0.25f; // d22 - d24
    DATA_TYPE part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
    DATA_TYPE part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
    DATA_TYPE part21 = part9 * 0.25f;                                         // 2.0f * (d21 - d23)
    DATA_TYPE part22 = part10 * 0.25f;                                        // - 4.0f * d21 + 5.0f * d23 - d25
    DATA_TYPE part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;

    out6 += part0 - part1;
    out12 += part0 + part1;
    out7 += part2 + part3 + part4 + part5;
    out8 += part2 - part3 + part4 - part5;
    out13 += part2 + part3 - part4 - part5;
    out14 += part2 - part3 - part4 + part5;
    out9 += part6 + part7 + part8 + part9;
    out10 += part6 - part7 + part8 - part9;
    out15 += part6 - part7 - part8 + part9;
    out16 += part6 + part7 - part8 - part9;
    out11 += part10 + part11;
    out17 += part10 - part11;

    out18 += part13 - part12;
    out24 += part13 + part12;
    out19 += part14 + part15 + part16 + part17;
    out20 += part14 - part15 + part16 - part17;
    out25 += part14 - part15 - part16 + part17;
    out26 += part14 + part15 - part16 - part17;
    out21 += part18 + part19 + part20 + part21;
    out22 += part18 - part19 + part20 - part21;
    out27 += part18 - part19 - part20 + part21;
    out28 += part18 + part19 - part20 - part21;
    out23 += part22 + part23;
    out29 += part22 - part23;

    *dst_addr = out6;
    dst_addr += dst_plane_stride;
    *dst_addr = out7;
    dst_addr += dst_plane_stride;
    *dst_addr = out8;
    dst_addr += dst_plane_stride;
    *dst_addr = out9;
    dst_addr += dst_plane_stride;
    *dst_addr = out10;
    dst_addr += dst_plane_stride;
    *dst_addr = out11;
    dst_addr += dst_plane_stride;
    *dst_addr = out12;
    dst_addr += dst_plane_stride;
    *dst_addr = out13;
    dst_addr += dst_plane_stride;
    *dst_addr = out14;
    dst_addr += dst_plane_stride;
    *dst_addr = out15;
    dst_addr += dst_plane_stride;
    *dst_addr = out16;
    dst_addr += dst_plane_stride;
    *dst_addr = out17;
    dst_addr += dst_plane_stride;

    *dst_addr = out18;
    dst_addr += dst_plane_stride;
    *dst_addr = out19;
    dst_addr += dst_plane_stride;
    *dst_addr = out20;
    dst_addr += dst_plane_stride;
    *dst_addr = out21;
    dst_addr += dst_plane_stride;
    *dst_addr = out22;
    dst_addr += dst_plane_stride;
    *dst_addr = out23;
    dst_addr += dst_plane_stride;
    *dst_addr = out24;
    dst_addr += dst_plane_stride;
    *dst_addr = out25;
    dst_addr += dst_plane_stride;
    *dst_addr = out26;
    dst_addr += dst_plane_stride;
    *dst_addr = out27;
    dst_addr += dst_plane_stride;
    *dst_addr = out28;
    dst_addr += dst_plane_stride;
    *dst_addr = out29;
    dst_addr += dst_plane_stride;

    // Row5
    z_coord  = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
    valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
    valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);

    DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Channels [30, 35]: output row 5 = 4 * row1 - 5 * row3 + row5 (B^T row 5), fully expanded
    out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
    out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
    out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;

    *dst_addr = out0;
    dst_addr += dst_plane_stride;
    *dst_addr = out1;
    dst_addr += dst_plane_stride;
    *dst_addr = out2;
    dst_addr += dst_plane_stride;
    *dst_addr = out3;
    dst_addr += dst_plane_stride;
    *dst_addr = out4;
    dst_addr += dst_plane_stride;
    *dst_addr = out5;
    dst_addr += dst_plane_stride;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
1292
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001293/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NHWC
Giorgio Arena149fdf32018-07-04 17:03:33 +01001294 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001299 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
Giorgio Arena149fdf32018-07-04 17:03:33 +01001300 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001301 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1302 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001303 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arena149fdf32018-07-04 17:03:33 +01001304 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001305 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arena149fdf32018-07-04 17:03:33 +01001306 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1307 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1308 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1309 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1310 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1311 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem (in bytes)
1313 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1314 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1315 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1316 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1317 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1318 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem (in bytes)
1320 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001321 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1322 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001323 */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001324__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
Giorgio Arena149fdf32018-07-04 17:03:33 +01001325 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001326 TENSOR3D_DECLARATION(dst),
1327 uint src_stride_w,
1328 uint dst_stride_w)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001329{
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001330 const int x = get_global_id(0);
1331 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001332#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001333 const int z = get_global_id(2) % NUM_TILES_Y;
1334 const int b = get_global_id(2) / NUM_TILES_Y;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001335#else /* defined(NUM_TILES_Y) */
1336 const int z = get_global_id(2);
1337#endif /* defined(NUM_TILES_Y) */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001338
1339 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001340#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001341 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001342#else /* defined(NUM_TILES_Y) */
1343 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
1344#endif /* defined(NUM_TILES_Y) */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001345
1346#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1347 // Clamp coordinates. This clamp is valid for all rows
1348 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001349 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001350
1351 // Row0
1352 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1353 int z_coord = z * OUTPUT_TILE_H;
1354
1355 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001356 VEC_DATA_TYPE(DATA_TYPE, 8)
1357 in_row0;
1358 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z);
1359 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z);
1360 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z);
1361 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z);
1362 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z);
1363 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z);
1364 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z);
1365 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001366
1367 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001368 VEC_DATA_TYPE(DATA_TYPE, 8)
1369 comm_fact0 = 0.0f;
1370 VEC_DATA_TYPE(DATA_TYPE, 8)
1371 tmp0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001372
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001373 VEC_DATA_TYPE(DATA_TYPE, 8)
1374 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001375
1376 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1377
1378#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1379 // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001380 int y_coord = y * (int)OUTPUT_TILE_W;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001381
1382 // Row0
1383 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1384 int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001385 int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
1386 valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1387 z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001388
1389 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001390 VEC_DATA_TYPE(DATA_TYPE, 8)
1391 in_row0;
1392 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z);
1393 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z);
1394 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z);
1395 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z);
1396 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z);
1397 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z);
1398 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z);
1399 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001400
1401 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001402 VEC_DATA_TYPE(DATA_TYPE, 8)
1403 comm_fact0 = 0.0f;
1404 VEC_DATA_TYPE(DATA_TYPE, 8)
1405 tmp0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001406
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001407 VEC_DATA_TYPE(DATA_TYPE, 8)
1408 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001409
1410 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1411#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001412 VEC_DATA_TYPE(DATA_TYPE, 8)
1413 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001414
1415 // Clamp coordinates. This clamp is valid for all rows
1416 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001417 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001418
1419 // Row0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001420 int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
1421 int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
1422 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1423 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001424
1425 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001426 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1427 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1428 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1429 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1430 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1431 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1432 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1433 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001434
1435 // Row1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001436 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
1437 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1438 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1439 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001440
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001441 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1442 in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1443 in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1444 in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1445 in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1446 in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1447 in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1448 in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001449
1450 // Row2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001451 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
1452 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1453 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1454 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001455
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001456 in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1457 in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1458 in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1459 in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1460 in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1461 in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1462 in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1463 in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001464
1465 // Row3
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001466 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
1467 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1468 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1469 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001470
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001471 in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1472 in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1473 in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1474 in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1475 in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1476 in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1477 in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1478 in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001479
1480 // Row4
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001481 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
1482 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1483 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1484 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001485
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001486 in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1487 in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1488 in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1489 in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1490 in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1491 in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1492 in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1493 in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001494
1495 // Row5
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001496 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
1497 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1498 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1499 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001500
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001501 in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1502 in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1503 in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1504 in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1505 in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1506 in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1507 in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1508 in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001509
1510 // Row6
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001511 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
1512 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1513 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1514 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001515
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001516 in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1517 in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1518 in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1519 in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1520 in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1521 in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1522 in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1523 in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001524
1525 // Row7
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001526 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
1527 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1528 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1529 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001530
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001531 in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1532 in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1533 in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1534 in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1535 in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1536 in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1537 in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1538 in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001539
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001540 VEC_DATA_TYPE(DATA_TYPE, 8)
1541 comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
1542 VEC_DATA_TYPE(DATA_TYPE, 8)
1543 comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
1544 VEC_DATA_TYPE(DATA_TYPE, 8)
1545 comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001546
1547 // Calculate intermediate tensor and reuse common factor vectors
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001548 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = in_row0 - in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
1549 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
1550 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001551
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001552 comm_fact0 = (DATA_TYPE)2.5f * in_row3;
1553 comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.f * in_row5;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001554
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001555 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
1556 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001557
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001558 comm_fact1 = (DATA_TYPE)2.f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
1559 comm_fact2 = (DATA_TYPE)4.f * in_row2 - (DATA_TYPE)5.f * in_row4 + in_row6;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001560
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001561 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
1562 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
1563 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001564
1565 // Calculate output rows (reuse comm_fact0 vector)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001566 VEC_DATA_TYPE(DATA_TYPE, 8)
1567 out0, out1, out2, out3, out4, out5, out6, out7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001568 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1569 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
1570 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
1571 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
1572 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
1573 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
1574 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
1575 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001576#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001577
1578 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001579#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001580 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001581#else /* NUM_TILES_Y */
1582 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1583#endif /* NUM_TILES_Y */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001584
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001585 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1586 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1587 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1588 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1589 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1590 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1591 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1592 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001593
1594#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001595 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1596 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1597 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1598 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1599 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1600 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1601 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1602 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1603 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1604 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1605 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1606 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1607 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1608 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1609 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1610 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1611 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1612 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1613 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1614 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1615 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1616 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1617 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1618 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1619 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1620 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1621 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1622 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1623 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1624 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1625 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1626 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1627 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1628 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1629 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1630 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1631 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1632 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1633 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1634 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1635 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1636 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1637 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1638 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1639 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1640 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1641 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1642 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1643 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1644 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1645 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1646 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1647 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1648 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1649 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1650 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001651#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001652}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001653#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001654
1655#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1656/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
1657 *
1658 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1659 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1660 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1661 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1662 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001663 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001664 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001665 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001666 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1667 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1668 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1669 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1670 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1671 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1672 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1673 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1674 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1675 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1676 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1677 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1678 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1679 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1680 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001681 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1682 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001683 */
1684__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
1685 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001686 TENSOR3D_DECLARATION(dst),
1687 uint src_stride_w,
1688 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001689{
1690 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
1691 src_stride_x,
1692 src_step_x,
1693 src_stride_y,
1694 src_step_y,
1695 src_stride_z,
1696 src_step_z,
1697 src_offset_first_element_in_bytes,
1698 dst_ptr,
1699 dst_stride_x,
1700 dst_step_x,
1701 dst_stride_y,
1702 dst_step_y,
1703 dst_stride_z,
1704 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001705 dst_offset_first_element_in_bytes,
1706 src_stride_w,
1707 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001708}
1709
1710/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
1711 *
1712 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1713 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1714 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1715 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1716 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001717 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001718 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001719 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001720 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1721 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1722 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1723 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1724 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1725 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1726 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1727 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1728 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1729 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1730 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1731 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1732 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1733 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1734 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001735 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1736 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001737 */
1738__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
1739 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001740 TENSOR3D_DECLARATION(dst),
1741 uint src_stride_w,
1742 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001743{
1744 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
1745 src_stride_x,
1746 src_step_x,
1747 src_stride_y,
1748 src_step_y,
1749 src_stride_z,
1750 src_step_z,
1751 src_offset_first_element_in_bytes,
1752 dst_ptr,
1753 dst_stride_x,
1754 dst_step_x,
1755 dst_stride_y,
1756 dst_step_y,
1757 dst_stride_z,
1758 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001759 dst_offset_first_element_in_bytes,
1760 src_stride_w,
1761 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001762}
1763
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/3x3 implementation.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001817
/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/5x5 implementation.
    winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1871
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001872#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 for data layout NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nhwc.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/3x3 NHWC implementation.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1928
/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 for data layout NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/5x5 NHWC implementation.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001985#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1986
1987#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 2x2/3x3 (stepz1) implementation.
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
2041
/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is a multiple of 2
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 2x2/3x3 (stepz2) implementation.
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
2095
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/3x3 implementation.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002149
/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note All arguments are forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per work-item (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per work-item (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per work-item (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Forward every argument unchanged to the 4x4/5x5 implementation.
    winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002203
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002204#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002205/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002206 *
2207 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002208 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2209 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002210 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002211 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002212 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002213 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002214 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002215 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002216 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002217 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2218 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2219 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2220 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2221 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2222 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2223 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002234 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Thin wrapper: the 1x4-output / 1x3-kernel transform is served by the 4x4/3x3 NHWC
    // kernel, with all eighteen arguments forwarded unchanged and in the same order.
    // This kernel lives in the WINOGRAD_INPUT_TRANSFORM_VERTICAL section of the file.
    // NOTE(review): presumably the callee keys off that macro to process a single tile
    // column -- confirm in winograd_input_transform_4x4_3x3_stepz1_nhwc.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(
        /* source tensor */
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        /* destination tensor */
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        /* W-dimension strides */
        src_stride_w, dst_stride_w);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002260
/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4 for data layout NHWC
 *
 * All arguments are forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Reuse the 4x4/5x5 NHWC implementation. NOTE(review): WINOGRAD_INPUT_TRANSFORM_VERTICAL
    // (required above) is presumably what restricts the callee to a single tile column --
    // confirm in winograd_input_transform_4x4_5x5_stepz1_nhwc.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        src_ptr,
        src_stride_x,
        src_step_x,
        src_stride_y,
        src_step_y,
        src_stride_z,
        src_step_z,
        src_offset_first_element_in_bytes,
        dst_ptr,
        dst_stride_x,
        dst_step_x,
        dst_stride_y,
        dst_step_y,
        dst_stride_z,
        dst_step_z,
        dst_offset_first_element_in_bytes,
        src_stride_w,
        dst_stride_w);
}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002316#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002317#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002318#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)