/*
 * Copyright (c) 2018-2019 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
25
/** Combine the eight lanes of @p tmp into the eight lanes of @p out.
 *
 * The macro reads tmp.s0 .. tmp.s7, uses comm_fact.s0 .. comm_fact.s6 as scratch storage for
 * sub-expressions shared by pairs of outputs, and writes out.s0 .. out.s7.
 * (Judging by its name it is the per-row step of the 4x4 output tile / 5x5 filter input transform;
 * its callers are outside this part of the file.)
 *
 * @note comm_fact.s2 is computed first and then reused by comm_fact.s3 and comm_fact.s6: statement order matters.
 * @note @p out must not be the same object as @p tmp: out.s1 .. out.s6 are written before out.s7 reads tmp.s1/s3/s5.
 * @note Implemented as a statement expression "({ ... })" (compiler extension), exactly like the original.
 * @note Every argument is parenthesised so that expressions such as "*ptr" can be passed safely.
 *
 * @param[out]     out       Object with components s0 .. s7 receiving the result
 * @param[in]      tmp       Object with components s0 .. s7 holding the input row
 * @param[in, out] comm_fact Object with (at least) components s0 .. s6 used as scratch; overwritten
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                                 \
    ({                                                                                          \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                                \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                                \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                                       \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;                     \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;                        \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                            \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;                     \
                                                                                                \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;                   \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                             \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                             \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                             \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                             \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                             \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                             \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;                   \
    })
45
/** Combine the eight lanes of @p tmp into the eight lanes of @p out.
 *
 * The macro reads tmp.s0 .. tmp.s7, uses comm_fact.s0 .. comm_fact.s5 as scratch storage for
 * sub-expressions shared by pairs of outputs, and writes out.s0 .. out.s7.
 * (Judging by its name it is the per-row step of the 2x2 output tile / 7x7 filter input transform;
 * its callers are outside this part of the file.)
 *
 * @note @p out must not be the same object as @p tmp: out.s1 .. out.s6 are written before out.s7 reads tmp.s1/s2/s3/s5.
 * @note The "0.0f * tmp.s2" term in out.s7 is kept on purpose: dropping it would change the result when tmp.s2 is NaN or Inf.
 * @note Implemented as a statement expression "({ ... })" (compiler extension), exactly like the original.
 * @note Every argument is parenthesised so that expressions such as "*ptr" can be passed safely.
 *
 * @param[out]     out       Object with components s0 .. s7 receiving the result
 * @param[in]      tmp       Object with components s0 .. s7 holding the input row
 * @param[in, out] comm_fact Object with (at least) components s0 .. s5 used as scratch; overwritten
 */
#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact)                                                                   \
    ({                                                                                                            \
        (comm_fact).s0 = 36.0f * (tmp).s2 - 13.0f * (tmp).s4 + (tmp).s6;                                          \
        (comm_fact).s1 = 36.0f * (tmp).s1 - 13.0f * (tmp).s3 + 1.0f * (tmp).s5;                                   \
        (comm_fact).s2 = 9.0f * (tmp).s2 - 10.0f * (tmp).s4 + (tmp).s6;                                           \
        (comm_fact).s3 = 18.0f * (tmp).s1 - 20.0f * (tmp).s3 + 2.0f * (tmp).s5;                                   \
        (comm_fact).s4 = 4.0f * (tmp).s2 - 5.0f * (tmp).s4 + (tmp).s6;                                            \
        (comm_fact).s5 = 12.0f * (tmp).s1 - 15.0f * (tmp).s3 + 3.0f * (tmp).s5;                                   \
        (out).s0 = -36.0f * (tmp).s0 + 49.0f * (tmp).s2 + -14.0f * (tmp).s4 + (tmp).s6;                           \
        (out).s1 = (comm_fact).s0 - (comm_fact).s1;                                                               \
        (out).s2 = (comm_fact).s0 + (comm_fact).s1;                                                               \
        (out).s3 = (comm_fact).s2 - (comm_fact).s3;                                                               \
        (out).s4 = (comm_fact).s2 + (comm_fact).s3;                                                               \
        (out).s5 = (comm_fact).s4 - (comm_fact).s5;                                                               \
        (out).s6 = (comm_fact).s4 + (comm_fact).s5;                                                               \
        (out).s7 = -36.0f * (tmp).s1 + 0.0f * (tmp).s2 + 49.0f * (tmp).s3 - 14.0f * (tmp).s5 + (tmp).s7;          \
    })
63
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010064#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
65/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
66 *
67 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
68 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
69 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
70 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
71 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
72 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010073 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010074 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010075 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010076 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
77 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
78 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
79 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
80 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
81 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
82 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
83 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
84 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
85 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
86 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
87 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
88 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
89 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
90 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +010091 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
92 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010093 */
94__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
95 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +010096 TENSOR3D_DECLARATION(dst),
97 uint src_stride_w,
98 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010099{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100100 const int x = get_global_id(0);
101 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000102#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100103 const int z = get_global_id(2) % SRC_DEPTH;
104 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000105#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000106 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000107#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100108
109 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000110#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100111 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000112#else /* defined(SRC_DEPTH) */
113 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
114#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100115
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100116 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100117
118#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100119 VEC_DATA_TYPE(DATA_TYPE, 4)
120 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100121#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100122 VEC_DATA_TYPE(DATA_TYPE, 4)
123 in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
124 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
125 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
126 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100127#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100128 VEC_DATA_TYPE(DATA_TYPE, 4)
129 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
130 VEC_DATA_TYPE(DATA_TYPE, 4)
131 in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
132 VEC_DATA_TYPE(DATA_TYPE, 4)
133 in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
134 VEC_DATA_TYPE(DATA_TYPE, 4)
135 in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100136#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
137
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100138 VEC_DATA_TYPE(DATA_TYPE, 4)
139 tmp0 = in_row0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100140
141#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
142 tmp0 -= in_row2;
143#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
144
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100145 DATA_TYPE out00 = tmp0.s0 - tmp0.s2;
146 DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
147 DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
148 DATA_TYPE out03 = tmp0.s1 - tmp0.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100149
150#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100151 VEC_DATA_TYPE(DATA_TYPE, 4)
152 tmp1 = in_row1 + in_row2;
153 VEC_DATA_TYPE(DATA_TYPE, 4)
154 tmp2 = in_row2 - in_row1;
155 VEC_DATA_TYPE(DATA_TYPE, 4)
156 tmp3 = in_row1 - in_row3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100157
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100158 DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
159 DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
160 DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
161 DATA_TYPE out13 = tmp1.s1 - tmp1.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100162
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100163 DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
164 DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
165 DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
166 DATA_TYPE out23 = tmp2.s1 - tmp2.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100167
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100168 DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
169 DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
170 DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
171 DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100172#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
173
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000174#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100175 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000176#else /* defined(SRC_DEPTH) */
177 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
178#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100179
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100180 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00; // in_row0.s0; out00;
181 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01; // in_row0.s1; out01;
182 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02; // in_row0.s2; out02;
183 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03; // in_row0.s3; out03;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100184
185#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100186 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out10;
187 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out11;
188 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out12;
189 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out13;
190 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out20;
191 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out21;
192 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
193 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
194 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
195 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
196 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
197 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100198#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
199}
200
201/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
202 *
203 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
204 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
205 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
206 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
207 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
208 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100209 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100210 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100211 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100212 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
213 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
214 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
215 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
216 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
217 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
218 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
219 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
220 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
221 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
222 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
223 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
224 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
225 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
226 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100227 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
228 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100229 */
230__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
231 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100232 TENSOR3D_DECLARATION(dst),
233 uint src_stride_w,
234 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100235{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100236 const int x = get_global_id(0);
237 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000238#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100239 const int z = (get_global_id(2) * 2) % SRC_DEPTH;
240 const int b = (get_global_id(2) * 2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000241#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000242 const int z = get_global_id(2) * 2;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000243#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100244
245 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000246#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100247 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000248#else /* defined(SRC_DEPTH) */
249 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
250#endif /* defined(SRC_DEPTH) */
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100251 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100252
253#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100254 VEC_DATA_TYPE(DATA_TYPE, 4)
255 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100256#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100257 VEC_DATA_TYPE(DATA_TYPE, 4)
258 in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
259 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
260 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
261 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100262#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100263 VEC_DATA_TYPE(DATA_TYPE, 4)
264 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
265 VEC_DATA_TYPE(DATA_TYPE, 4)
266 in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
267 VEC_DATA_TYPE(DATA_TYPE, 4)
268 in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
269 VEC_DATA_TYPE(DATA_TYPE, 4)
270 in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100271#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
272
273 src_addr += src_stride_z;
274#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100275 VEC_DATA_TYPE(DATA_TYPE, 4)
276 in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100277#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100278 VEC_DATA_TYPE(DATA_TYPE, 4)
279 in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
280 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
281 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
282 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100283#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100284 VEC_DATA_TYPE(DATA_TYPE, 4)
285 in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
286 VEC_DATA_TYPE(DATA_TYPE, 4)
287 in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
288 VEC_DATA_TYPE(DATA_TYPE, 4)
289 in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
290 VEC_DATA_TYPE(DATA_TYPE, 4)
291 in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100292#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
293
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100294 VEC_DATA_TYPE(DATA_TYPE, 4)
295 tmp0 = in_row0;
296 VEC_DATA_TYPE(DATA_TYPE, 4)
297 tmp4 = in_row4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100298
299#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
300 tmp0 -= in_row2;
301 tmp4 -= in_row6;
302#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
303
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100304 VEC_DATA_TYPE(DATA_TYPE, 2)
305 out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
306 VEC_DATA_TYPE(DATA_TYPE, 2)
307 out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
308 VEC_DATA_TYPE(DATA_TYPE, 2)
309 out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
310 VEC_DATA_TYPE(DATA_TYPE, 2)
311 out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100312
313#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100314 VEC_DATA_TYPE(DATA_TYPE, 4)
315 tmp1 = in_row1 + in_row2;
316 VEC_DATA_TYPE(DATA_TYPE, 4)
317 tmp2 = in_row2 - in_row1;
318 VEC_DATA_TYPE(DATA_TYPE, 4)
319 tmp3 = in_row1 - in_row3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100320
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100321 VEC_DATA_TYPE(DATA_TYPE, 4)
322 tmp5 = in_row5 + in_row6;
323 VEC_DATA_TYPE(DATA_TYPE, 4)
324 tmp6 = in_row6 - in_row5;
325 VEC_DATA_TYPE(DATA_TYPE, 4)
326 tmp7 = in_row5 - in_row7;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100327
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100328 VEC_DATA_TYPE(DATA_TYPE, 2)
329 out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
330 VEC_DATA_TYPE(DATA_TYPE, 2)
331 out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
332 VEC_DATA_TYPE(DATA_TYPE, 2)
333 out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
334 VEC_DATA_TYPE(DATA_TYPE, 2)
335 out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100336
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100337 VEC_DATA_TYPE(DATA_TYPE, 2)
338 out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
339 VEC_DATA_TYPE(DATA_TYPE, 2)
340 out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
341 VEC_DATA_TYPE(DATA_TYPE, 2)
342 out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
343 VEC_DATA_TYPE(DATA_TYPE, 2)
344 out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100345
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100346 VEC_DATA_TYPE(DATA_TYPE, 2)
347 out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
348 VEC_DATA_TYPE(DATA_TYPE, 2)
349 out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
350 VEC_DATA_TYPE(DATA_TYPE, 2)
351 out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
352 VEC_DATA_TYPE(DATA_TYPE, 2)
353 out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100354#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
355
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000356#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100357 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000358#else /* defined(SRC_DEPTH) */
359 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
360#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100361
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100362 vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z));
363 vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
364 vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
365 vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100366
367#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100368 vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z));
369 vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
370 vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
371 vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
372 vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
373 vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
374 vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
375 vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
376 vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
377 vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
378 vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
379 vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100380#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
381}
382
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100383/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100384 *
385 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
386 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
387 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
388 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
389 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
390 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100391 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100392 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100393 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100394 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
395 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
396 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
397 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
398 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
399 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
400 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
401 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
402 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
403 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
404 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
405 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
406 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
407 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
408 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100409 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
410 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100411 */
412__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
413 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100414 TENSOR3D_DECLARATION(dst),
415 uint src_stride_w,
416 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100417{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100418 const int x = get_global_id(0);
419 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000420#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100421 const int z = get_global_id(2) % SRC_DEPTH;
422 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000423#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000424 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000425#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100426
427 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000428#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100429 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000430#else /* defined(SRC_DEPTH) */
431 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
432#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100433
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100434 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100435
436#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
437 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100438 VEC_DATA_TYPE(DATA_TYPE, 4)
439 d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
440 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
441 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
442 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
443 VEC_DATA_TYPE(DATA_TYPE, 2)
444 d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
445 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100446#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
447 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100448 VEC_DATA_TYPE(DATA_TYPE, 4)
449 d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
450 VEC_DATA_TYPE(DATA_TYPE, 2)
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000451 d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100452#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
453
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100454 DATA_TYPE out0 = 0.0f;
455 DATA_TYPE out1 = 0.0f;
456 DATA_TYPE out2 = 0.0f;
457 DATA_TYPE out3 = 0.0f;
458 DATA_TYPE out4 = 0.0f;
459 DATA_TYPE out5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100460
461 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
462 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
463 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
464 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
465 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
466 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
467 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
468
469#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
470 // Row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100471 VEC_DATA_TYPE(DATA_TYPE, 4)
472 d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
473 VEC_DATA_TYPE(DATA_TYPE, 2)
474 d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100475
476 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100477 DATA_TYPE k0 = d41.s0;
478 DATA_TYPE k1 = d41.s0;
479 DATA_TYPE k2 = d41.s0;
480 DATA_TYPE k3 = d41.s0;
481 DATA_TYPE k4 = d41.s0;
482 DATA_TYPE k5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100483
484 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
485 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
486 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
487 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
488 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
489 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
490
491 out0 += k0;
492 out1 += k1;
493 out2 += k2;
494 out3 += k3;
495 out4 += k4;
496 out5 += k5;
497
498 // Row2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100499 VEC_DATA_TYPE(DATA_TYPE, 4)
500 d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
501 VEC_DATA_TYPE(DATA_TYPE, 2)
502 d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100503
504 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
505 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
506 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
507 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
508 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
509 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
510#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
511
512 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000513#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100514 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000515#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000516 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000517#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100518
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100519 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100520
521 *(dst_addr) = out0;
522 dst_addr += dst_plane_stride;
523 *(dst_addr) = out1;
524 dst_addr += dst_plane_stride;
525 *(dst_addr) = out2;
526 dst_addr += dst_plane_stride;
527 *(dst_addr) = out3;
528 dst_addr += dst_plane_stride;
529 *(dst_addr) = out4;
530 dst_addr += dst_plane_stride;
531 *(dst_addr) = out5;
532 dst_addr += dst_plane_stride;
533
534#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100535 DATA_TYPE out6 = k0;
536 DATA_TYPE out7 = k1;
537 DATA_TYPE out8 = k2;
538 DATA_TYPE out9 = k3;
539 DATA_TYPE out10 = k4;
540 DATA_TYPE out11 = k5;
541 DATA_TYPE out12 = k0;
542 DATA_TYPE out13 = k1;
543 DATA_TYPE out14 = k2;
544 DATA_TYPE out15 = k3;
545 DATA_TYPE out16 = k4;
546 DATA_TYPE out17 = k5;
547 DATA_TYPE out18 = k0;
548 DATA_TYPE out19 = k1;
549 DATA_TYPE out20 = k2;
550 DATA_TYPE out21 = k3;
551 DATA_TYPE out22 = k4;
552 DATA_TYPE out23 = k5;
553 DATA_TYPE out24 = k0;
554 DATA_TYPE out25 = k1;
555 DATA_TYPE out26 = k2;
556 DATA_TYPE out27 = k3;
557 DATA_TYPE out28 = k4;
558 DATA_TYPE out29 = k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100559
560 // Row1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100561 VEC_DATA_TYPE(DATA_TYPE, 4)
562 d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
563 VEC_DATA_TYPE(DATA_TYPE, 2)
564 d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100565
566 // Row3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100567 VEC_DATA_TYPE(DATA_TYPE, 4)
568 d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
569 VEC_DATA_TYPE(DATA_TYPE, 2)
570 d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100571
572 // Compute common parts for the channels between [6, 29]
573 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
574 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100575 DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
576 DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
577 DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
578 DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
579 DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
580 DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
581 DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
582 DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
583 DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
584 DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
585 DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
586 DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100587
588 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
589 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100590 DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
591 DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
592 DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
593 DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
594 DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
595 DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
596 DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
597 DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
598 DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
599 DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
600 DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
601 DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100602
603 out6 += part0 - part1;
604 out12 += part0 + part1;
605 out7 += part2 + part3 + part4 + part5;
606 out8 += part2 - part3 + part4 - part5;
607 out13 += part2 + part3 - part4 - part5;
608 out14 += part2 - part3 - part4 + part5;
609 out9 += part6 + part7 + part8 + part9;
610 out10 += part6 - part7 + part8 - part9;
611 out15 += part6 - part7 - part8 + part9;
612 out16 += part6 + part7 - part8 - part9;
613 out11 += part10 + part11;
614 out17 += part10 - part11;
615
616 out18 += part13 - part12;
617 out24 += part13 + part12;
618 out19 += part14 + part15 + part16 + part17;
619 out20 += part14 - part15 + part16 - part17;
620 out25 += part14 - part15 - part16 + part17;
621 out26 += part14 + part15 - part16 - part17;
622 out21 += part18 + part19 + part20 + part21;
623 out22 += part18 - part19 + part20 - part21;
624 out27 += part18 - part19 - part20 + part21;
625 out28 += part18 + part19 - part20 - part21;
626 out23 += part22 + part23;
627 out29 += part22 - part23;
628
629 *(dst_addr) = out6;
630 dst_addr += dst_plane_stride;
631 *(dst_addr) = out7;
632 dst_addr += dst_plane_stride;
633 *(dst_addr) = out8;
634 dst_addr += dst_plane_stride;
635 *(dst_addr) = out9;
636 dst_addr += dst_plane_stride;
637 *(dst_addr) = out10;
638 dst_addr += dst_plane_stride;
639 *(dst_addr) = out11;
640 dst_addr += dst_plane_stride;
641 *(dst_addr) = out12;
642 dst_addr += dst_plane_stride;
643 *(dst_addr) = out13;
644 dst_addr += dst_plane_stride;
645 *(dst_addr) = out14;
646 dst_addr += dst_plane_stride;
647 *(dst_addr) = out15;
648 dst_addr += dst_plane_stride;
649 *(dst_addr) = out16;
650 dst_addr += dst_plane_stride;
651 *(dst_addr) = out17;
652 dst_addr += dst_plane_stride;
653
654 *(dst_addr) = out18;
655 dst_addr += dst_plane_stride;
656 *(dst_addr) = out19;
657 dst_addr += dst_plane_stride;
658 *(dst_addr) = out20;
659 dst_addr += dst_plane_stride;
660 *(dst_addr) = out21;
661 dst_addr += dst_plane_stride;
662 *(dst_addr) = out22;
663 dst_addr += dst_plane_stride;
664 *(dst_addr) = out23;
665 dst_addr += dst_plane_stride;
666 *(dst_addr) = out24;
667 dst_addr += dst_plane_stride;
668 *(dst_addr) = out25;
669 dst_addr += dst_plane_stride;
670 *(dst_addr) = out26;
671 dst_addr += dst_plane_stride;
672 *(dst_addr) = out27;
673 dst_addr += dst_plane_stride;
674 *(dst_addr) = out28;
675 dst_addr += dst_plane_stride;
676 *(dst_addr) = out29;
677 dst_addr += dst_plane_stride;
678
679 // Row5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100680 VEC_DATA_TYPE(DATA_TYPE, 4)
681 d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
682 VEC_DATA_TYPE(DATA_TYPE, 2)
683 d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100684
685 // Channels [30, 35]
686 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
687 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
688 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
689 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
690 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
691 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
692
693 *(dst_addr) = out0;
694 dst_addr += dst_plane_stride;
695 *(dst_addr) = out1;
696 dst_addr += dst_plane_stride;
697 *(dst_addr) = out2;
698 dst_addr += dst_plane_stride;
699 *(dst_addr) = out3;
700 dst_addr += dst_plane_stride;
701 *(dst_addr) = out4;
702 dst_addr += dst_plane_stride;
703 *(dst_addr) = out5;
704 dst_addr += dst_plane_stride;
705#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
706}
707
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100708/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
709 *
710 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
711 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
712 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
713 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
714 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
715 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
716 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
717 *
718 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
719 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
720 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
721 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
722 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
723 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
724 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
725 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
726 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
727 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
728 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
729 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
730 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
731 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
732 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
733 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
734 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
735 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
736 */
737__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
738 TENSOR3D_DECLARATION(src),
739 TENSOR3D_DECLARATION(dst),
740 uint src_stride_w,
741 uint dst_stride_w)
742{
743 const int x = get_global_id(0);
744 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000745#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100746 const int z = get_global_id(2) % SRC_DEPTH;
747 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000748#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000749 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000750#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100751
752 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000753#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100754 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000755#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000756 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000757#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100758 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
759
760 // Load input tile
761#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
762 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
763#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
764 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
765 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
766 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
767 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
768 *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
769 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
770 *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
771 *((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
772#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
773 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
774 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
775 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
776 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
777 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
778 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
779 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
780 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
781#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
782
783 // Calculate common factors for intermediate tensor
784 VEC_DATA_TYPE(DATA_TYPE, 8)
785 tmp0 = in_row0;
786 VEC_DATA_TYPE(DATA_TYPE, 8)
787 comm_fact0 = 0.0f;
788
789#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
790 comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25 * in_row4;
791 tmp0 += -in_row6 + (DATA_TYPE)5.25 * in_row4 - (DATA_TYPE)5.25 * in_row2;
792
793 VEC_DATA_TYPE(DATA_TYPE, 8)
794 comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25 * in_row3;
795 VEC_DATA_TYPE(DATA_TYPE, 8)
796 comm_fact2 = (DATA_TYPE)0.25 * in_row2 - (DATA_TYPE)1.25 * in_row4 + in_row6;
797
798 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
799 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
800
801 comm_fact0 = (DATA_TYPE)2.5 * in_row3;
802 comm_fact1 = (DATA_TYPE)0.5 * in_row1 - comm_fact0 + (DATA_TYPE)2.0 * in_row5;
803
804 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
805 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
806
807 comm_fact1 = (DATA_TYPE)2.0 * in_row1 - comm_fact0 + (DATA_TYPE)0.5 * in_row5;
808 comm_fact2 = (DATA_TYPE)4.0 * in_row2 - (DATA_TYPE)5.0 * in_row4 + in_row6;
809
810 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
811 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
812 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25 * in_row3 - (DATA_TYPE)5.25 * in_row5;
813#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
814
815 // Calculate output rows (reuse comm_fact0 vector)
816 VEC_DATA_TYPE(DATA_TYPE, 8)
817 out0;
818
819 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
820
821#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
822 VEC_DATA_TYPE(DATA_TYPE, 8)
823 out1, out2, out3, out4, out5, out6, out7;
824
825 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
826 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
827 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
828 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
829 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
830 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
831 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
832#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
833
834 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000835#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100836 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000837#else /* defined(SRC_DEPTH) */
838 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
839#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100840
841 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
842 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
843 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
844 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
845 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
846 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
847 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
848 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
849
850#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
851 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
852 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
853 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
854 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
855 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
856 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
857 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
858 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
859 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
860 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
861 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
862 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
863 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
864 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
865 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
866 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
867 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
868 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
869 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
870 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
871 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
872 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
873 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
874 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
875 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
876 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
877 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
878 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
879 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
880 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
881 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
882 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
883 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
884 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
885 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
886 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
887 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
888 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
889 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
890 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
891 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
892 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
893 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
894 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
895 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
896 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
897 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
898 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
899 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
900 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
901 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
902 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
903 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
904 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
905 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
906 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
907#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
908}
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100909
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000910#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
942__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
943 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100944 TENSOR3D_DECLARATION(dst),
945 uint src_stride_w,
946 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100947{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100948 const int x = get_global_id(0);
949 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000950#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100951 const int z = get_global_id(2) % NUM_TILES_Y;
952 const int b = get_global_id(2) / NUM_TILES_Y;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000953#else /* defined(NUM_TILES_Y) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000954 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000955#endif /* defined(NUM_TILES_Y) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100956
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000957#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100958 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000959#else /* defined(NUM_TILES_Y) */
960 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
961#endif /* defined(NUM_TILES_Y) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100962
963 // Clamp coordinates. This clamp is valid for all rows
Giorgio Arena149fdf32018-07-04 17:03:33 +0100964 int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
965 int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100966 y_coord0 = clamp(y_coord0, (int4) - 1, (int4)SRC_DIM_1);
967 y_coord1 = clamp(y_coord1, (int2) - 1, (int2)SRC_DIM_1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100968
Giorgio Arena149fdf32018-07-04 17:03:33 +0100969 int z_coord;
970 int4 valid_y0;
971 int2 valid_y1;
972
973#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100974 // Row4
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100975 z_coord = (z * 4) - (int)PAD_TOP + 4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100976
977 // If z < 0, set y to -1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100978 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
979 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100980 // If z >= SRC_DIM_2, set y to SRC_DIM_2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100981 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
982 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100983
984 // Clamp z coordinate
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +0100985 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100986
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100987 DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
988 DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
989 DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
990 DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
991 DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
992 DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100993
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100994 DATA_TYPE k0 = d44;
995 DATA_TYPE k1 = d44;
996 DATA_TYPE k2 = d44;
997 DATA_TYPE k3 = d44;
998 DATA_TYPE k4 = d44;
999 DATA_TYPE k5 = (DATA_TYPE)0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001000
1001 k0 += 4.0f * d40 - 5.0f * d42;
1002 k1 += -4.0f * d41 - 4.0f * d42 + d43;
1003 k2 += 4.0f * d41 - 4.0f * d42 - d43;
1004 k3 += -2.0f * d41 + 2.0f * d43 - d42;
1005 k4 += 2.0f * d41 - 2.0f * d43 - d42;
1006 k5 += 4.0f * d41 - 5.0f * d43 + d45;
Giorgio Arena149fdf32018-07-04 17:03:33 +01001007#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001008
Giorgio Arena149fdf32018-07-04 17:03:33 +01001009#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001010 // Row0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001011 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001012
1013#if PAD_TOP != 0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001014 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
1015 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
1016 valid_y0 = select(valid_y0, (int)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
1017 valid_y1 = select(valid_y1, (int)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
1018 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001019#else // PAD_TOP != 0
1020 valid_y0 = y_coord0;
1021 valid_y1 = y_coord1;
1022#endif // if PAD_TOP == 0, we cannot read out of bound
1023
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001024 DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1025 DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1026 DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1027 DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1028 DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1029 DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arena149fdf32018-07-04 17:03:33 +01001030#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001031 int4 z_coords0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
1032 int2 z_coords1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001033
Giorgio Arena149fdf32018-07-04 17:03:33 +01001034 valid_y0 = select((int4)y_coord0.s0, (int4) - 1, z_coords0 < (int4)0);
1035 valid_y1 = select((int2)y_coord0.s0, (int2) - 1, z_coords1 < (int2)0);
1036 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, z_coords0 >= (int4)SRC_DIM_2);
1037 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, z_coords1 >= (int2)SRC_DIM_2);
1038
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001039 z_coords0 = clamp((int4)z_coords0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
1040 z_coords1 = clamp((int2)z_coords1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
Giorgio Arena149fdf32018-07-04 17:03:33 +01001041
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001042 DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coords0.s0 * src_stride_z);
1043 DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coords0.s1 * src_stride_z);
1044 DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coords0.s2 * src_stride_z);
1045 DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coords0.s3 * src_stride_z);
1046 DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coords1.s0 * src_stride_z);
1047 DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coords1.s1 * src_stride_z);
Giorgio Arena149fdf32018-07-04 17:03:33 +01001048#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1049
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001050 DATA_TYPE out0 = 16.0f * d00 - 20.0f * d02 + 4.0f * d04;
1051 DATA_TYPE out1 = -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 4.0f * d04;
1052 DATA_TYPE out2 = 16.0f * d01 - 16.0f * d02 - 4.0f * d03 + 4.0f * d04;
1053 DATA_TYPE out3 = -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 4.0f * d04;
1054 DATA_TYPE out4 = 8.0f * d01 - 4.0f * d02 - 8.0f * d03 + 4.0f * d04;
1055 DATA_TYPE out5 = 16.0f * d01 - 20.0f * d03 + 4.0f * d05;
Giorgio Arena149fdf32018-07-04 17:03:33 +01001056
1057#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001058 // Row2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001059 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
1060 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
1061 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
1062 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
1063 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
1064 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001065
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001066 DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1067 DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1068 DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1069 DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1070 DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1071 DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001072
Giorgio Arena149fdf32018-07-04 17:03:33 +01001073 out0 += k0;
1074 out1 += k1;
1075 out2 += k2;
1076 out3 += k3;
1077 out4 += k4;
1078 out5 += k5;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001079 DATA_TYPE out6 = k0;
1080 DATA_TYPE out7 = k1;
1081 DATA_TYPE out8 = k2;
1082 DATA_TYPE out9 = k3;
1083 DATA_TYPE out10 = k4;
1084 DATA_TYPE out11 = k5;
1085 DATA_TYPE out12 = k0;
1086 DATA_TYPE out13 = k1;
1087 DATA_TYPE out14 = k2;
1088 DATA_TYPE out15 = k3;
1089 DATA_TYPE out16 = k4;
1090 DATA_TYPE out17 = k5;
1091 DATA_TYPE out18 = k0;
1092 DATA_TYPE out19 = k1;
1093 DATA_TYPE out20 = k2;
1094 DATA_TYPE out21 = k3;
1095 DATA_TYPE out22 = k4;
1096 DATA_TYPE out23 = k5;
1097 DATA_TYPE out24 = k0;
1098 DATA_TYPE out25 = k1;
1099 DATA_TYPE out26 = k2;
1100 DATA_TYPE out27 = k3;
1101 DATA_TYPE out28 = k4;
1102 DATA_TYPE out29 = k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001103
1104 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
Giorgio Arena149fdf32018-07-04 17:03:33 +01001105 out0 += -20.0f * d20 + 25.0f * d22 - 5.0f * d24;
1106 out1 += 20.0f * d21 + 20.0f * d22 - 5.0f * d23 - 5.0f * d24;
1107 out2 += -20.0f * d21 + 20.0f * d22 + 5.0f * d23 - 5.0f * d24;
1108 out3 += 10.0f * d21 + 5.0f * d22 - 10.0f * d23 - 5.0f * d24;
1109 out4 += -10.0f * d21 + 5.0f * d22 + 10.0f * d23 - 5.0f * d24;
1110 out5 += -20.0f * d21 + 25.0f * d23 - 5.0f * d25;
1111#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1112
1113 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001114#if defined(NUM_TILES_Y)
1115 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
1116#else /* defined(NUM_TILES_Y) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001117 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001118#endif /* defined(NUM_TILES_Y) */
1119
1120 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001121
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001122 *((__global DATA_TYPE *)dst_addr) = out0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001123 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001124 *((__global DATA_TYPE *)dst_addr) = out1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001125 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001126 *((__global DATA_TYPE *)dst_addr) = out2;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001127 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001128 *((__global DATA_TYPE *)dst_addr) = out3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001129 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001130 *((__global DATA_TYPE *)dst_addr) = out4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001131 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001132 *((__global DATA_TYPE *)dst_addr) = out5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001133 dst_addr += dst_plane_stride;
1134
Giorgio Arena149fdf32018-07-04 17:03:33 +01001135#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001136 // Row1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001137 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001138 // Row1 can never be out of bounds
1139 valid_y0 = y_coord0;
1140 valid_y1 = y_coord1;
1141
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001142 DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1143 DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1144 DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1145 DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1146 DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1147 DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001148
1149 // Row3
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001150 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
1151 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
1152 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
1153 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
1154 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
1155 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1156 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001157
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001158 DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1159 DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1160 DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1161 DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1162 DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1163 DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001164
1165 // Compute common parts for the channels between [6, 29]
1166 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
1167 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001168 DATA_TYPE part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
1169 DATA_TYPE part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
1170 DATA_TYPE part2 = 16.0f * d22 - 4.0f * d24;
1171 DATA_TYPE part3 = 16.0f * d21 - 4.0f * d23;
1172 DATA_TYPE part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
1173 DATA_TYPE part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
1174 DATA_TYPE part6 = 4.0f * d22 - 4.0f * d24;
1175 DATA_TYPE part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
1176 DATA_TYPE part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
1177 DATA_TYPE part9 = 8.0f * d21 - 8.0f * d23;
1178 DATA_TYPE part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
1179 DATA_TYPE part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001180
1181 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
1182 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001183 DATA_TYPE part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
1184 DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
1185 DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d22 - d24
1186 DATA_TYPE part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
1187 DATA_TYPE part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
1188 DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d21 - d23
1189 DATA_TYPE part18 = part6 * 0.25f; // d22 - d24
1190 DATA_TYPE part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
1191 DATA_TYPE part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
1192 DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
1193 DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
1194 DATA_TYPE part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001195
1196 out6 += part0 - part1;
1197 out12 += part0 + part1;
1198 out7 += part2 + part3 + part4 + part5;
1199 out8 += part2 - part3 + part4 - part5;
1200 out13 += part2 + part3 - part4 - part5;
1201 out14 += part2 - part3 - part4 + part5;
1202 out9 += part6 + part7 + part8 + part9;
1203 out10 += part6 - part7 + part8 - part9;
1204 out15 += part6 - part7 - part8 + part9;
1205 out16 += part6 + part7 - part8 - part9;
1206 out11 += part10 + part11;
1207 out17 += part10 - part11;
1208
1209 out18 += part13 - part12;
1210 out24 += part13 + part12;
1211 out19 += part14 + part15 + part16 + part17;
1212 out20 += part14 - part15 + part16 - part17;
1213 out25 += part14 - part15 - part16 + part17;
1214 out26 += part14 + part15 - part16 - part17;
1215 out21 += part18 + part19 + part20 + part21;
1216 out22 += part18 - part19 + part20 - part21;
1217 out27 += part18 - part19 - part20 + part21;
1218 out28 += part18 + part19 - part20 - part21;
1219 out23 += part22 + part23;
1220 out29 += part22 - part23;
1221
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001222 *((__global DATA_TYPE *)dst_addr) = out6;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001223 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001224 *((__global DATA_TYPE *)dst_addr) = out7;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001225 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001226 *((__global DATA_TYPE *)dst_addr) = out8;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001227 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001228 *((__global DATA_TYPE *)dst_addr) = out9;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001229 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001230 *((__global DATA_TYPE *)dst_addr) = out10;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001231 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001232 *((__global DATA_TYPE *)dst_addr) = out11;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001233 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001234 *((__global DATA_TYPE *)dst_addr) = out12;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001235 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001236 *((__global DATA_TYPE *)dst_addr) = out13;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001237 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001238 *((__global DATA_TYPE *)dst_addr) = out14;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001239 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001240 *((__global DATA_TYPE *)dst_addr) = out15;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001241 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001242 *((__global DATA_TYPE *)dst_addr) = out16;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001243 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001244 *((__global DATA_TYPE *)dst_addr) = out17;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001245 dst_addr += dst_plane_stride;
1246
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001247 *((__global DATA_TYPE *)dst_addr) = out18;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001248 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001249 *((__global DATA_TYPE *)dst_addr) = out19;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001250 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001251 *((__global DATA_TYPE *)dst_addr) = out20;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001252 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001253 *((__global DATA_TYPE *)dst_addr) = out21;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001254 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001255 *((__global DATA_TYPE *)dst_addr) = out22;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001256 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001257 *((__global DATA_TYPE *)dst_addr) = out23;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001258 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001259 *((__global DATA_TYPE *)dst_addr) = out24;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001260 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001261 *((__global DATA_TYPE *)dst_addr) = out25;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001262 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001263 *((__global DATA_TYPE *)dst_addr) = out26;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001264 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001265 *((__global DATA_TYPE *)dst_addr) = out27;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001266 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001267 *((__global DATA_TYPE *)dst_addr) = out28;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001268 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001269 *((__global DATA_TYPE *)dst_addr) = out29;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001270 dst_addr += dst_plane_stride;
1271
1272 // Row5
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001273 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
1274 valid_y0 = select(y_coord0, (int4) - 1, (int4)z_coord < 0);
1275 valid_y1 = select(y_coord1, (int2) - 1, (int2)z_coord < 0);
1276 valid_y0 = select(valid_y0, (int4)SRC_DIM_1, (int4)z_coord >= (int)SRC_DIM_2);
1277 valid_y1 = select(valid_y1, (int2)SRC_DIM_1, (int2)z_coord >= (int)SRC_DIM_2);
1278 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1279 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001280
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001281 DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1282 DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1283 DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1284 DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1285 DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1286 DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001287
1288 // Channels [30, 35]
1289 out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
1290 out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
1291 out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
1292 out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
1293 out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
1294 out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
1295
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001296 *((__global DATA_TYPE *)dst_addr) = out0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001297 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001298 *((__global DATA_TYPE *)dst_addr) = out1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001299 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001300 *((__global DATA_TYPE *)dst_addr) = out2;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001301 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001302 *((__global DATA_TYPE *)dst_addr) = out3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001303 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001304 *((__global DATA_TYPE *)dst_addr) = out4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001305 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001306 *((__global DATA_TYPE *)dst_addr) = out5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001307 dst_addr += dst_plane_stride;
Giorgio Arena149fdf32018-07-04 17:03:33 +01001308#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001309}
1310
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001311/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NHWC
Giorgio Arena149fdf32018-07-04 17:03:33 +01001312 *
1313 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1314 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001317 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
Giorgio Arena149fdf32018-07-04 17:03:33 +01001318 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001319 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1320 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001321 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arena149fdf32018-07-04 17:03:33 +01001322 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001323 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arena149fdf32018-07-04 17:03:33 +01001324 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1325 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1326 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1327 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1328 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1329 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1331 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1332 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1333 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1334 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1335 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1336 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1338 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001339 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1340 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001341 */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001342__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
Giorgio Arena149fdf32018-07-04 17:03:33 +01001343 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001344 TENSOR3D_DECLARATION(dst),
1345 uint src_stride_w,
1346 uint dst_stride_w)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001347{
    // Work-item mapping:
    //  - x: element index along the innermost dimension (scaled by sizeof(DATA_TYPE) below)
    //  - y: tile index along dimension 1 (scaled by OUTPUT_TILE_W, addressed with src_stride_y)
    //  - z: tile index along dimension 2 (scaled by OUTPUT_TILE_H, addressed with src_stride_z)
    //  - b: batch index (addressed with src_stride_w / dst_stride_w); only available when NUM_TILES_Y is
    //       defined, in which case global id 2 packs both z and b
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001348 const int x = get_global_id(0);
1349 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001350#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001351 const int z = get_global_id(2) % NUM_TILES_Y;
1352 const int b = get_global_id(2) / NUM_TILES_Y;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001353#else /* defined(NUM_TILES_Y) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001354 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001355#endif /* defined(NUM_TILES_Y) */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001356
1357 // Compute input address (x offset, plus the batch offset when NUM_TILES_Y is defined)
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001358#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001359 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001360#else /* defined(NUM_TILES_Y) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001361 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001362#endif /* defined(NUM_TILES_Y) */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001363
1364#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1365 // Coordinates along dimension 1 of the 8 input elements of the tile, clamped to [-1, SRC_DIM_1]
1366 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001367 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
    // NOTE(review): -1 and SRC_DIM_1 are one element outside [0, SRC_DIM_1 - 1]; presumably the source
    // tensor exposes a readable (zero-filled) border at those positions - confirm against the host-side setup.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001368
1369 // Row0
1370 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1371 int z_coord = z * OUTPUT_TILE_H;
1372
1373 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001374 VEC_DATA_TYPE(DATA_TYPE, 8)
1375 in_row0;
1376 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord.s0 * (int)src_stride_y + z_coord * src_stride_z);
1377 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord.s1 * (int)src_stride_y + z_coord * src_stride_z);
1378 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord.s2 * (int)src_stride_y + z_coord * src_stride_z);
1379 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord.s3 * (int)src_stride_y + z_coord * src_stride_z);
1380 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord.s4 * (int)src_stride_y + z_coord * src_stride_z);
1381 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord.s5 * (int)src_stride_y + z_coord * src_stride_z);
1382 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord.s6 * (int)src_stride_y + z_coord * src_stride_z);
1383 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001384
1385 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001386 VEC_DATA_TYPE(DATA_TYPE, 8)
1387 comm_fact0 = 0.0f;
1388 VEC_DATA_TYPE(DATA_TYPE, 8)
1389 tmp0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001390
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001391 VEC_DATA_TYPE(DATA_TYPE, 8)
1392 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001393
    // Single 1D transform: only out0 is produced in this branch
1394 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1395
1396#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1397 // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001398 int y_coord = y * (int)OUTPUT_TILE_W;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001399
1400 // Row0
1401 // Coordinates along dimension 2 of the 8 input elements of the tile; out-of-range values are handled below
1402 int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001403 int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
1404 valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1
1405 z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
    // NOTE(review): redirecting y to -1 / SRC_DIM_1 for out-of-range z presumably makes the load hit a
    // zero-filled border element instead of real data - confirm against the host-side padding.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001406
1407 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001408 VEC_DATA_TYPE(DATA_TYPE, 8)
1409 in_row0;
1410 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * src_stride_z);
1411 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * src_stride_z);
1412 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * src_stride_z);
1413 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * src_stride_z);
1414 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * src_stride_z);
1415 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * src_stride_z);
1416 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * src_stride_z);
1417 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001418
1419 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001420 VEC_DATA_TYPE(DATA_TYPE, 8)
1421 comm_fact0 = 0.0f;
1422 VEC_DATA_TYPE(DATA_TYPE, 8)
1423 tmp0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001424
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001425 VEC_DATA_TYPE(DATA_TYPE, 8)
1426 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001427
    // Single 1D transform: only out0 is produced in this branch
1428 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1429#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 2D case: load an 8x8 input tile (in_row0..in_row7), one row per z coordinate
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001430 VEC_DATA_TYPE(DATA_TYPE, 8)
1431 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001432
1433 // Clamp coordinates. This clamp is valid for all rows
1434 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001435 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001436
1437 // Row0
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001438 int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
1439 int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
1440 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1
1441 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001442
1443 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001444 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1445 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1446 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1447 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1448 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1449 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1450 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1451 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001452
1453 // Row1
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001454 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
1455 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1456 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1457 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001458
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001459 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1460 in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1461 in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1462 in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1463 in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1464 in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1465 in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1466 in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001467
1468 // Row2
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001469 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
1470 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1471 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1472 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001473
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001474 in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1475 in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1476 in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1477 in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1478 in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1479 in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1480 in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1481 in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001482
1483 // Row3
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001484 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
1485 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1486 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1487 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001488
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001489 in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1490 in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1491 in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1492 in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1493 in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1494 in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1495 in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1496 in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001497
1498 // Row4
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001499 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
1500 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1501 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1502 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001503
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001504 in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1505 in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1506 in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1507 in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1508 in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1509 in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1510 in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1511 in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001512
1513 // Row5
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001514 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
1515 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1516 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1517 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001518
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001519 in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1520 in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1521 in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1522 in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1523 in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1524 in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1525 in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1526 in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001527
1528 // Row6
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001529 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
1530 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1531 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1532 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001533
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001534 in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1535 in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1536 in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1537 in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1538 in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1539 in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1540 in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1541 in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001542
1543 // Row7
Georgios Pinitasdbbe4a32018-07-18 18:06:32 +01001544 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
1545 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1546 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1547 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001548
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001549 in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
1550 in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
1551 in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
1552 in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
1553 in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
1554 in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
1555 in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
1556 in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001557
    // First pass: combine the 8 loaded rows element-wise into 8 intermediate rows tmp0..tmp7.
    // comm_fact0..comm_fact2 hold sub-expressions shared by pairs of tmp rows and are overwritten as we go.
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001558 VEC_DATA_TYPE(DATA_TYPE, 8)
1559 comm_fact0 = in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
1560 VEC_DATA_TYPE(DATA_TYPE, 8)
1561 comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
1562 VEC_DATA_TYPE(DATA_TYPE, 8)
1563 comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001564
1565 // Calculate intermediate tensor and reuse common factor vectors
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001566 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = in_row0 - in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
1567 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
1568 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001569
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001570 comm_fact0 = (DATA_TYPE)2.5f * in_row3;
1571 comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.f * in_row5;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001572
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001573 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
1574 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001575
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001576 comm_fact1 = (DATA_TYPE)2.f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
1577 comm_fact2 = (DATA_TYPE)4.f * in_row2 - (DATA_TYPE)5.f * in_row4 + in_row6;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001578
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001579 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
1580 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
1581 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001582
1583 // Calculate output rows (reuse comm_fact0 vector)
    // Second pass: OUTPUT_ROW_4x4_5x5 is applied to each intermediate row; comm_fact0 is only scratch space here
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001584 VEC_DATA_TYPE(DATA_TYPE, 8)
1585 out0, out1, out2, out3, out4, out5, out6, out7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001586 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
1587 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
1588 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
1589 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
1590 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
1591 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
1592 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
1593 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001594#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1595
1596 // Store values across the channels
    // Destination layout: x is kept as-is, the tile (y, z) is linearised as y + z * NUM_TILES_X along dst y,
    // and every transformed element goes to its own dst z-slice (0..7 for the 1D variants, 0..63 for 2D).
1597#if defined(NUM_TILES_Y)
1598 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
1599#else /* NUM_TILES_Y */
1600 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1601#endif /* NUM_TILES_Y */
1602
1603 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1604 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1605 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1606 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1607 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1608 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1609 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1610 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1611
    // Rows out1..out7 exist only in the 2D variant
1612#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1613 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1614 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1615 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1616 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1617 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1618 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1619 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1620 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1621 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1622 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1623 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1624 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1625 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1626 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1627 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1628 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1629 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1630 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1631 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1632 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1633 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1634 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1635 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1636 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1637 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1638 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1639 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1640 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1641 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1642 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1643 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1644 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1645 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1646 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1647 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1648 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1649 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1650 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1651 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1652 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1653 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1654 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1655 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1656 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1657 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1658 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1659 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1660 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1661 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1662 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1663 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1664 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1665 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1666 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1667 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1668 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
1669#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1670}
1671
1672/** This OpenCL kernel computes the input transform when the kernel size is 7x7/7x1/1x7 and the output tile is 2x2/2x1/1x2 when the data layout is NHWC
1673 *
1674 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=7).
1675 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1676 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
1677 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
1678 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1679 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1680 * @note If this kernel is used to perform Winograd input transform 7x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1681 * @note If this kernel is used to perform Winograd input transform 1x7, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
1682 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
1683 *
1684 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
1685 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1686 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1687 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1688 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1689 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1690 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1691 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1692 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1693 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1694 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1695 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1696 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1697 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1698 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1699 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1700 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1701 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
1702 */
1703__kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
1704 TENSOR3D_DECLARATION(src),
1705 TENSOR3D_DECLARATION(dst),
1706 uint src_stride_w,
1707 uint dst_stride_w)
1708{
1709 const int x = get_global_id(0);
1710 const int y = get_global_id(1);
1711#if defined(NUM_TILES_Y)
1712 const int z = get_global_id(2) % NUM_TILES_Y;
1713 const int b = get_global_id(2) / NUM_TILES_Y;
1714#else /* defined(NUM_TILES_Y) */
1715 const int z = get_global_id(2);
1716#endif /* defined(NUM_TILES_Y) */
1717
1718 // Compute input address
1719#if defined(NUM_TILES_Y)
1720 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
1721#else /* defined(NUM_TILES_Y) */
1722 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
1723#endif /* defined(NUM_TILES_Y) */
1724
1725#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1726
1727 // Clamp coordinates. This clamp is valid for all rows
1728 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
1729 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
1730
1731 // Clamp coordinates. This clamp is valid for all columns
1732 int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
1733 int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
1734 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1735 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1736
1737 // Load the input tile
1738 VEC_DATA_TYPE(DATA_TYPE, 8)
1739 in_row0;
1740 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1741 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1742 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1743 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1744 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1745 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1746 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1747 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1748
1749 VEC_DATA_TYPE(DATA_TYPE, 8)
1750 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1751
1752 VEC_DATA_TYPE(DATA_TYPE, 8)
1753 tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
1754
1755 VEC_DATA_TYPE(DATA_TYPE, 8)
1756 comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1757
1758 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1759
1760#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1761 // We can skip the border clamping along the y dimension as we cannot read out-of-bound in case of 1x5 kernels
1762 int y_coord = y * (int)OUTPUT_TILE_W;
1763
1764 // Row0
1765 // We can skip the border clamping along the z dimension as we cannot read out-of-bound in case of 5x1 kernels
1766 int8 z_coord = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
1767 int8 valid_y = select((int8)y_coord, (int8) - 1, z_coord < (int8)0); // If z < 0, set y to -1
1768 valid_y = select(valid_y, (int8)SRC_DIM_1, z_coord >= (int8)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1769 z_coord = clamp(z_coord, (int8)0, (int8)SRC_DIM_2 - 1); // Clamp z coordinate
1770
1771 // Load the input tile
1772 VEC_DATA_TYPE(DATA_TYPE, 8)
1773 in_row0;
1774 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord.s0 * (int)src_stride_z);
1775 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord.s1 * (int)src_stride_z);
1776 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord.s2 * (int)src_stride_z);
1777 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord.s3 * (int)src_stride_z);
1778 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord.s4 * (int)src_stride_z);
1779 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord.s5 * (int)src_stride_z);
1780 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord.s6 * (int)src_stride_z);
1781 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord.s7 * (int)src_stride_z);
1782
1783 // Calculate common factors for intermediate tensor
1784 VEC_DATA_TYPE(DATA_TYPE, 8)
1785 tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
1786
1787 VEC_DATA_TYPE(DATA_TYPE, 8)
1788 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1789
1790 VEC_DATA_TYPE(DATA_TYPE, 8)
1791 comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1792
1793 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1794#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1795 VEC_DATA_TYPE(DATA_TYPE, 8)
1796 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
1797
1798 // Clamp coordinates. This clamp is valid for all rows
1799 int8 y_coord = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
1800 y_coord = clamp(y_coord, (int8) - 1, (int8)SRC_DIM_1);
1801
1802 // Row0
1803 int z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 0;
1804 int8 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0); // If z < 0, set y to -1
1805 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2
1806 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1); // Clamp z coordinate
1807
1808 // Load the input tile
1809 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1810 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1811 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1812 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1813 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1814 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1815 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1816 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1817
1818 // Row1
1819 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 1;
1820 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1821 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1822 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1823
1824 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1825 in_row1.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1826 in_row1.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1827 in_row1.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1828 in_row1.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1829 in_row1.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1830 in_row1.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1831 in_row1.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1832
1833 // Row2
1834 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 2;
1835 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1836 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1837 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1838
1839 in_row2.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1840 in_row2.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1841 in_row2.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1842 in_row2.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1843 in_row2.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1844 in_row2.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1845 in_row2.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1846 in_row2.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1847
1848 // Row3
1849 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 3;
1850 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1851 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1852 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1853
1854 in_row3.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1855 in_row3.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1856 in_row3.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1857 in_row3.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1858 in_row3.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1859 in_row3.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1860 in_row3.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1861 in_row3.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1862
1863 // Row4
1864 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 4;
1865 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1866 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1867 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1868
1869 in_row4.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1870 in_row4.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1871 in_row4.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1872 in_row4.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1873 in_row4.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1874 in_row4.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1875 in_row4.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1876 in_row4.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1877
1878 // Row5
1879 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 5;
1880 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1881 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1882 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1883
1884 in_row5.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1885 in_row5.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1886 in_row5.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1887 in_row5.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1888 in_row5.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1889 in_row5.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1890 in_row5.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1891 in_row5.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1892
1893 // Row6
1894 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 6;
1895 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1896 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1897 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1898
1899 in_row6.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1900 in_row6.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1901 in_row6.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1902 in_row6.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1903 in_row6.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1904 in_row6.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1905 in_row6.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1906 in_row6.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1907
1908 // Row7
1909 z_coord = (z * (int)OUTPUT_TILE_H) - (int)PAD_TOP + 7;
1910 valid_y = select(y_coord, (int8) - 1, (int8)z_coord < 0);
1911 valid_y = select(valid_y, (int8)SRC_DIM_1, (int8)z_coord >= (int)SRC_DIM_2);
1912 z_coord = clamp(z_coord, 0, (int)SRC_DIM_2 - 1);
1913
1914 in_row7.s0 = *(__global DATA_TYPE *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * (int)src_stride_z);
1915 in_row7.s1 = *(__global DATA_TYPE *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * (int)src_stride_z);
1916 in_row7.s2 = *(__global DATA_TYPE *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * (int)src_stride_z);
1917 in_row7.s3 = *(__global DATA_TYPE *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * (int)src_stride_z);
1918 in_row7.s4 = *(__global DATA_TYPE *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * (int)src_stride_z);
1919 in_row7.s5 = *(__global DATA_TYPE *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * (int)src_stride_z);
1920 in_row7.s6 = *(__global DATA_TYPE *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * (int)src_stride_z);
1921 in_row7.s7 = *(__global DATA_TYPE *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * (int)src_stride_z);
1922
1923 VEC_DATA_TYPE(DATA_TYPE, 8)
1924 comm_fact0 = (DATA_TYPE)36.0f * in_row2 - (DATA_TYPE)13.0f * in_row4 + in_row6;
1925 VEC_DATA_TYPE(DATA_TYPE, 8)
1926 comm_fact1 = (DATA_TYPE)36.0f * in_row1 - (DATA_TYPE)13.0f * in_row3 + in_row5;
1927 VEC_DATA_TYPE(DATA_TYPE, 8)
1928 comm_fact2 = (DATA_TYPE)9.0f * in_row2 - (DATA_TYPE)10.0f * in_row4 + in_row6;
1929 VEC_DATA_TYPE(DATA_TYPE, 8)
1930 comm_fact3 = (DATA_TYPE)18.0f * in_row1 - (DATA_TYPE)20.0f * in_row3 + (DATA_TYPE)2.0f * in_row5;
1931 VEC_DATA_TYPE(DATA_TYPE, 8)
1932 comm_fact4 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
1933 VEC_DATA_TYPE(DATA_TYPE, 8)
1934 comm_fact5 = (DATA_TYPE)12.0f * in_row1 - (DATA_TYPE)15.0f * in_row3 + (DATA_TYPE)3.0f * in_row5;
1935
1936 // Calculate intermediate tensors
1937 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = -(DATA_TYPE)36.0f * in_row0 + (DATA_TYPE)49.0f * in_row2 - (DATA_TYPE)14.0f * in_row4 + in_row6;
1938 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 - comm_fact1;
1939 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 + comm_fact1;
1940 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact2 - comm_fact3;
1941 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 + comm_fact3;
1942 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact4 - comm_fact5;
1943 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact4 + comm_fact5;
1944 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = -(DATA_TYPE)36.0f * in_row1 + (DATA_TYPE)49.0f * in_row3 - (DATA_TYPE)14.0f * in_row5 + in_row7;
1945
1946 VEC_DATA_TYPE(DATA_TYPE, 8)
1947 out0, out1, out2, out3, out4, out5, out6, out7;
1948
1949 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1950 OUTPUT_ROW_2x2_7x7(out1, tmp1, comm_fact0);
1951 OUTPUT_ROW_2x2_7x7(out2, tmp2, comm_fact0);
1952 OUTPUT_ROW_2x2_7x7(out3, tmp3, comm_fact0);
1953 OUTPUT_ROW_2x2_7x7(out4, tmp4, comm_fact0);
1954 OUTPUT_ROW_2x2_7x7(out5, tmp5, comm_fact0);
1955 OUTPUT_ROW_2x2_7x7(out6, tmp6, comm_fact0);
1956 OUTPUT_ROW_2x2_7x7(out7, tmp7, comm_fact0);
1957
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001958#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001959
1960 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001961#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001962 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001963#else /* NUM_TILES_Y */
1964 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1965#endif /* NUM_TILES_Y */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001966
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001967 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1968 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1969 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1970 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1971 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1972 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1973 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1974 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001975
1976#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001977 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1978 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1979 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1980 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1981 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1982 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1983 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1984 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1985 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1986 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1987 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1988 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1989 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1990 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1991 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1992 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1993 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1994 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1995 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1996 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1997 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1998 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1999 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
2000 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
2001 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
2002 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
2003 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
2004 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
2005 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
2006 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
2007 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
2008 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
2009 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
2010 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
2011 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
2012 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
2013 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
2014 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
2015 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
2016 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
2017 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
2018 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
2019 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
2020 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
2021 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
2022 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
2023 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
2024 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
2025 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
2026 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
2027 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
2028 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
2029 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
2030 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
2031 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
2032 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002033#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena149fdf32018-07-04 17:03:33 +01002034}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002035#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002036
2037#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
2038/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
2039 *
2040 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2041 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2042 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
2043 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2044 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002045 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002046 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002047 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002048 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2049 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2050 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2051 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2052 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2053 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2054 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2055 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2056 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2057 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2058 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2059 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2060 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2061 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2062 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002063 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2064 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002065 */
2066__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
2067 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002068 TENSOR3D_DECLARATION(dst),
2069 uint src_stride_w,
2070 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002071{
2072 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
2073 src_stride_x,
2074 src_step_x,
2075 src_stride_y,
2076 src_step_y,
2077 src_stride_z,
2078 src_step_z,
2079 src_offset_first_element_in_bytes,
2080 dst_ptr,
2081 dst_stride_x,
2082 dst_step_x,
2083 dst_stride_y,
2084 dst_step_y,
2085 dst_stride_z,
2086 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002087 dst_offset_first_element_in_bytes,
2088 src_stride_w,
2089 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002090}
2091
2092/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
2093 *
2094 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2095 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2096 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
2097 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2098 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002099 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002100 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002101 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002102 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2103 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2104 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2105 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2106 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2107 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2108 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2109 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2110 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2111 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2112 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2113 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2114 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2115 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2116 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002117 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2118 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002119 */
2120__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
2121 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002122 TENSOR3D_DECLARATION(dst),
2123 uint src_stride_w,
2124 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002125{
2126 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
2127 src_stride_x,
2128 src_step_x,
2129 src_stride_y,
2130 src_step_y,
2131 src_stride_z,
2132 src_step_z,
2133 src_offset_first_element_in_bytes,
2134 dst_ptr,
2135 dst_stride_x,
2136 dst_step_x,
2137 dst_stride_y,
2138 dst_step_y,
2139 dst_stride_z,
2140 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002141 dst_offset_first_element_in_bytes,
2142 src_stride_w,
2143 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002144}
2145
2146/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
2147 *
2148 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2149 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2150 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
2151 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2152 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002153 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002154 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002155 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002156 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2157 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2158 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2159 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2160 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2161 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2162 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2163 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2164 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2165 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2166 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2167 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2168 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2169 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2170 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002171 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2172 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002173 */
2174__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
2175 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002176 TENSOR3D_DECLARATION(dst),
2177 uint src_stride_w,
2178 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002179{
2180 winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
2181 src_stride_x,
2182 src_step_x,
2183 src_stride_y,
2184 src_step_y,
2185 src_stride_z,
2186 src_step_z,
2187 src_offset_first_element_in_bytes,
2188 dst_ptr,
2189 dst_stride_x,
2190 dst_step_x,
2191 dst_stride_y,
2192 dst_step_y,
2193 dst_stride_z,
2194 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002195 dst_offset_first_element_in_bytes,
2196 src_stride_w,
2197 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002198}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002199
2200/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW
2201 *
2202 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2203 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2206 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002207 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002208 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002209 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002210 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2211 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2212 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2213 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2214 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2215 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2216 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2217 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2218 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2219 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2220 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2221 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2222 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2223 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2224 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002225 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2226 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002227 */
2228__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
2229 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002230 TENSOR3D_DECLARATION(dst),
2231 uint src_stride_w,
2232 uint dst_stride_w)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002233{
2234 winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
2235 src_stride_x,
2236 src_step_x,
2237 src_stride_y,
2238 src_step_y,
2239 src_stride_z,
2240 src_step_z,
2241 src_offset_first_element_in_bytes,
2242 dst_ptr,
2243 dst_stride_x,
2244 dst_step_x,
2245 dst_stride_y,
2246 dst_step_y,
2247 dst_stride_z,
2248 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002249 dst_offset_first_element_in_bytes,
2250 src_stride_w,
2251 dst_stride_w);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002252}
2253
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002254#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** Winograd input transform for a 3x1 kernel and a 4x1 output tile (NHWC data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nhwc.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112).
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1).
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
2310
/** Winograd input transform for a 5x1 kernel and a 4x1 output tile (NHWC data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112).
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=4).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1).
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002366
/** Winograd input transform for a 7x1 kernel and a 2x1 output tile (NHWC data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_2x2_7x7_stepz1_nhwc.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=7).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112).
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=2).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=1).
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x1_7x1_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_2x2_7x7_stepz1_nhwc(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002423#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
2424
2425#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
/** Winograd input transform for a 1x3 kernel and a 1x2 output tile (NCHW data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_2x2_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
2479
/** Winograd input transform for a 1x3 kernel and a 1x2 output tile when the number of channels is a multiple of 2 (NCHW data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=2).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_2x2_3x3_stepz2_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
2533
/** Winograd input transform for a 1x3 kernel and a 1x4 output tile (NCHW data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_4x4_3x3_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002587
/** Winograd input transform for a 1x5 kernel and a 1x4 output tile (NCHW data layout).
 *
 * Thin wrapper: every argument is forwarded unchanged to winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @note The number of tiles along the X axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The left and top padding must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W (e.g. -DOUTPUT_TILE_W=1).
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H (e.g. -DOUTPUT_TILE_H=4).
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time.
 * @note The data type must be passed at compile time using -DDATA_TYPE (e.g. -DDATA_TYPE=float). Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source tensor. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] src_offset_first_element_in_bytes Offset of the first element in the source tensor (in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in] dst_offset_first_element_in_bytes Offset of the first element in the destination tensor (in bytes)
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Arguments are grouped per tensor: source, destination, then the two W strides.
    winograd_input_transform_4x4_5x5_stepz1_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z, src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes,
        src_stride_w, dst_stride_w);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002641
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002642#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002643/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002644 *
2645 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002646 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2647 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002648 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002649 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002650 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002651 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002652 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002653 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002654 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002655 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2656 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2657 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2658 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2659 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2660 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2661 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2662 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2663 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2664 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2665 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2666 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2667 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2668 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2669 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002670 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2671 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002672 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Thin entry point: every argument is handed over unchanged, in the same order,
    // to winograd_input_transform_4x4_3x3_stepz1_nhwc, which does the actual work.
    winograd_input_transform_4x4_3x3_stepz1_nhwc(
        // source tensor
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        // destination tensor
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        // W-dimension strides
        src_stride_w, dst_stride_w);
}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002698
/** Winograd input transform entry point for a 1x5 kernel and a 1x4 output tile, NHWC data layout
 *
 * This kernel only forwards its arguments, unchanged and in the same order, to
 * winograd_input_transform_4x4_5x5_stepz1_nhwc.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 4x4/5x5 NHWC transform; nothing is computed in this wrapper.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        // source tensor
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        // destination tensor
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        // W-dimension strides
        src_stride_w, dst_stride_w);
}
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002754
/** Winograd input transform entry point for a 1x7 kernel and a 1x2 output tile, NHWC data layout
 *
 * This kernel only forwards its arguments, unchanged and in the same order, to
 * winograd_input_transform_2x2_7x7_stepz1_nhwc.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=7).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x7_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 2x2/7x7 NHWC transform; nothing is computed in this wrapper.
    winograd_input_transform_2x2_7x7_stepz1_nhwc(
        // source tensor
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y, src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        // destination tensor
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y, dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        // W-dimension strides
        src_stride_w, dst_stride_w);
}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002810#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002811#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002812#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)