blob: 8c382183c3de37ded729b9cad5329650e58149f8 [file] [log] [blame]
/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
Gian Marco Iodice534b8892021-04-01 16:17:16 +010025#include "tile_helpers.h"
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010026
/** Transforms one 8-element row: reads components s0..s7 of @p tmp and writes components s0..s7 of @p out.
 *
 * The macro name suggests it serves the 4x4 output tile / 5x5 filter Winograd configuration
 * (NOTE(review): its callers are not visible in this part of the file - confirm).
 *
 * @param[out]     out       Destination; must expose writable components s0..s7.
 * @param[in]      tmp       Source row; components s0..s7 are read, some more than once,
 *                           so do not pass an expression with side effects.
 * @param[in, out] comm_fact Scratch storage; components s0..s6 are overwritten with the shared
 *                           sub-expressions reused by several outputs.
 *
 * The body is a GNU statement expression, so the macro is used like a single statement.
 * Arguments are parenthesized so that any lvalue expression can be passed safely.
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact) \
    ({ \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5; \
        (comm_fact).s2 = 2.5f * (tmp).s3; \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2; \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4; \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2; \
        \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2; \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1; \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1; \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4; \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3; \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6; \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6; \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5; \
    })
46
/** Transforms one 8-element row: reads components s0..s7 of @p tmp and writes components s0..s7 of @p out.
 *
 * The macro name suggests it serves the 2x2 output tile / 7x7 filter Winograd configuration
 * (NOTE(review): its callers are not visible in this part of the file - confirm).
 *
 * @param[out]     out       Destination; must expose writable components s0..s7.
 * @param[in]      tmp       Source row; components s0..s7 are read, some more than once,
 *                           so do not pass an expression with side effects.
 * @param[in, out] comm_fact Scratch storage; components s0..s5 are overwritten with the shared
 *                           sub-expressions reused by several outputs.
 *
 * The "1.0f *" and "0.0f *" terms are kept exactly as originally written so that the
 * floating-point evaluation (including NaN/Inf propagation) is unchanged.
 * The body is a GNU statement expression, so the macro is used like a single statement.
 * Arguments are parenthesized so that any lvalue expression can be passed safely.
 */
#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact) \
    ({ \
        (comm_fact).s0 = 36.0f * (tmp).s2 - 13.0f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s1 = 36.0f * (tmp).s1 - 13.0f * (tmp).s3 + 1.0f * (tmp).s5; \
        (comm_fact).s2 = 9.0f * (tmp).s2 - 10.0f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s3 = 18.0f * (tmp).s1 - 20.0f * (tmp).s3 + 2.0f * (tmp).s5; \
        (comm_fact).s4 = 4.0f * (tmp).s2 - 5.0f * (tmp).s4 + (tmp).s6; \
        (comm_fact).s5 = 12.0f * (tmp).s1 - 15.0f * (tmp).s3 + 3.0f * (tmp).s5; \
        (out).s0 = -36.0f * (tmp).s0 + 49.0f * (tmp).s2 + -14.0f * (tmp).s4 + (tmp).s6; \
        (out).s1 = (comm_fact).s0 - (comm_fact).s1; \
        (out).s2 = (comm_fact).s0 + (comm_fact).s1; \
        (out).s3 = (comm_fact).s2 - (comm_fact).s3; \
        (out).s4 = (comm_fact).s2 + (comm_fact).s3; \
        (out).s5 = (comm_fact).s4 - (comm_fact).s5; \
        (out).s6 = (comm_fact).s4 + (comm_fact).s5; \
        (out).s7 = -36.0f * (tmp).s1 + 0.0f * (tmp).s2 + 49.0f * (tmp).s3 - 14.0f * (tmp).s5 + (tmp).s7; \
    })
64
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010065#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
66/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
67 *
68 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
69 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
70 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
71 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
72 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
73 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010074 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010075 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +010076 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010077 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
78 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
79 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
80 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
81 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
82 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
83 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
84 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
85 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
86 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
87 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
88 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
89 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
90 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
91 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +010092 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
93 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +010094 */
95__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
96 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +010097 TENSOR3D_DECLARATION(dst),
98 uint src_stride_w,
99 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100100{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100101 const int x = get_global_id(0);
102 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000103#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100104 const int z = get_global_id(2) % SRC_DEPTH;
105 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000106#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000107 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000108#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100109
110 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000111#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100112 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000113#else /* defined(SRC_DEPTH) */
114 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
115#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100116
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100117 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100118
119#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100120 VEC_DATA_TYPE(DATA_TYPE, 4)
121 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100122#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100123 VEC_DATA_TYPE(DATA_TYPE, 4)
124 in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
125 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
126 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
127 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100128#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100129 VEC_DATA_TYPE(DATA_TYPE, 4)
130 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
131 VEC_DATA_TYPE(DATA_TYPE, 4)
132 in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
133 VEC_DATA_TYPE(DATA_TYPE, 4)
134 in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
135 VEC_DATA_TYPE(DATA_TYPE, 4)
136 in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100137#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
138
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100139 VEC_DATA_TYPE(DATA_TYPE, 4)
140 tmp0 = in_row0;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100141
142#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
143 tmp0 -= in_row2;
144#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
145
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100146 DATA_TYPE out00 = tmp0.s0 - tmp0.s2;
147 DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
148 DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
149 DATA_TYPE out03 = tmp0.s1 - tmp0.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100150
151#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100152 VEC_DATA_TYPE(DATA_TYPE, 4)
153 tmp1 = in_row1 + in_row2;
154 VEC_DATA_TYPE(DATA_TYPE, 4)
155 tmp2 = in_row2 - in_row1;
156 VEC_DATA_TYPE(DATA_TYPE, 4)
157 tmp3 = in_row1 - in_row3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100158
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100159 DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
160 DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
161 DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
162 DATA_TYPE out13 = tmp1.s1 - tmp1.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100163
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100164 DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
165 DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
166 DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
167 DATA_TYPE out23 = tmp2.s1 - tmp2.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100168
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100169 DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
170 DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
171 DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
172 DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100173#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
174
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000175#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100176 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000177#else /* defined(SRC_DEPTH) */
178 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
179#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100180
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100181 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00; // in_row0.s0; out00;
182 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01; // in_row0.s1; out01;
183 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02; // in_row0.s2; out02;
184 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03; // in_row0.s3; out03;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100185
186#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100187 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out10;
188 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out11;
189 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out12;
190 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out13;
191 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out20;
192 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out21;
193 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
194 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
195 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
196 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
197 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
198 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100199#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
200}
201
202/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
203 *
204 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
205 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
206 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
207 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
208 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
209 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100210 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100211 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100212 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100213 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
214 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
215 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
216 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
217 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
218 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
219 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
220 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
221 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
222 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
223 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
224 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
225 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
226 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
227 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100228 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
229 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100230 */
231__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
232 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100233 TENSOR3D_DECLARATION(dst),
234 uint src_stride_w,
235 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100236{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100237 const int x = get_global_id(0);
238 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000239#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100240 const int z = (get_global_id(2) * 2) % SRC_DEPTH;
241 const int b = (get_global_id(2) * 2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000242#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000243 const int z = get_global_id(2) * 2;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000244#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100245
246 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000247#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100248 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000249#else /* defined(SRC_DEPTH) */
250 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
251#endif /* defined(SRC_DEPTH) */
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100252 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100253
254#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100255 VEC_DATA_TYPE(DATA_TYPE, 4)
256 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100257#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100258 VEC_DATA_TYPE(DATA_TYPE, 4)
259 in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
260 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
261 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
262 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100263#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100264 VEC_DATA_TYPE(DATA_TYPE, 4)
265 in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
266 VEC_DATA_TYPE(DATA_TYPE, 4)
267 in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
268 VEC_DATA_TYPE(DATA_TYPE, 4)
269 in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
270 VEC_DATA_TYPE(DATA_TYPE, 4)
271 in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100272#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
273
274 src_addr += src_stride_z;
275#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100276 VEC_DATA_TYPE(DATA_TYPE, 4)
277 in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100278#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100279 VEC_DATA_TYPE(DATA_TYPE, 4)
280 in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
281 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
282 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
283 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100284#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100285 VEC_DATA_TYPE(DATA_TYPE, 4)
286 in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
287 VEC_DATA_TYPE(DATA_TYPE, 4)
288 in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
289 VEC_DATA_TYPE(DATA_TYPE, 4)
290 in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
291 VEC_DATA_TYPE(DATA_TYPE, 4)
292 in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100293#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
294
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100295 VEC_DATA_TYPE(DATA_TYPE, 4)
296 tmp0 = in_row0;
297 VEC_DATA_TYPE(DATA_TYPE, 4)
298 tmp4 = in_row4;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100299
300#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
301 tmp0 -= in_row2;
302 tmp4 -= in_row6;
303#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
304
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100305 VEC_DATA_TYPE(DATA_TYPE, 2)
306 out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
307 VEC_DATA_TYPE(DATA_TYPE, 2)
308 out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
309 VEC_DATA_TYPE(DATA_TYPE, 2)
310 out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
311 VEC_DATA_TYPE(DATA_TYPE, 2)
312 out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100313
314#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100315 VEC_DATA_TYPE(DATA_TYPE, 4)
316 tmp1 = in_row1 + in_row2;
317 VEC_DATA_TYPE(DATA_TYPE, 4)
318 tmp2 = in_row2 - in_row1;
319 VEC_DATA_TYPE(DATA_TYPE, 4)
320 tmp3 = in_row1 - in_row3;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100321
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100322 VEC_DATA_TYPE(DATA_TYPE, 4)
323 tmp5 = in_row5 + in_row6;
324 VEC_DATA_TYPE(DATA_TYPE, 4)
325 tmp6 = in_row6 - in_row5;
326 VEC_DATA_TYPE(DATA_TYPE, 4)
327 tmp7 = in_row5 - in_row7;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100328
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100329 VEC_DATA_TYPE(DATA_TYPE, 2)
330 out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
331 VEC_DATA_TYPE(DATA_TYPE, 2)
332 out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
333 VEC_DATA_TYPE(DATA_TYPE, 2)
334 out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
335 VEC_DATA_TYPE(DATA_TYPE, 2)
336 out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100337
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100338 VEC_DATA_TYPE(DATA_TYPE, 2)
339 out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
340 VEC_DATA_TYPE(DATA_TYPE, 2)
341 out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
342 VEC_DATA_TYPE(DATA_TYPE, 2)
343 out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
344 VEC_DATA_TYPE(DATA_TYPE, 2)
345 out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100346
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100347 VEC_DATA_TYPE(DATA_TYPE, 2)
348 out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
349 VEC_DATA_TYPE(DATA_TYPE, 2)
350 out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
351 VEC_DATA_TYPE(DATA_TYPE, 2)
352 out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
353 VEC_DATA_TYPE(DATA_TYPE, 2)
354 out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100355#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
356
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000357#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100358 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000359#else /* defined(SRC_DEPTH) */
360 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
361#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100362
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100363 vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z));
364 vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
365 vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
366 vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100367
368#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100369 vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z));
370 vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
371 vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
372 vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
373 vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
374 vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
375 vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
376 vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
377 vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
378 vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
379 vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
380 vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100381#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
382}
383
/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
413__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
414 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100415 TENSOR3D_DECLARATION(dst),
416 uint src_stride_w,
417 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100418{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100419 const int x = get_global_id(0);
420 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000421#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100422 const int z = get_global_id(2) % SRC_DEPTH;
423 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000424#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000425 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000426#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100427
428 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000429#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100430 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000431#else /* defined(SRC_DEPTH) */
432 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
433#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100434
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100435 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100436
437#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
438 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100439 VEC_DATA_TYPE(DATA_TYPE, 4)
440 d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
441 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
442 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
443 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
444 VEC_DATA_TYPE(DATA_TYPE, 2)
445 d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
446 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100447#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
448 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100449 VEC_DATA_TYPE(DATA_TYPE, 4)
450 d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
451 VEC_DATA_TYPE(DATA_TYPE, 2)
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000452 d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100453#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
454
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100455 DATA_TYPE out0 = 0.0f;
456 DATA_TYPE out1 = 0.0f;
457 DATA_TYPE out2 = 0.0f;
458 DATA_TYPE out3 = 0.0f;
459 DATA_TYPE out4 = 0.0f;
460 DATA_TYPE out5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100461
462 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
463 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
464 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
465 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
466 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
467 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
468 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
469
470#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
471 // Row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100472 VEC_DATA_TYPE(DATA_TYPE, 4)
473 d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
474 VEC_DATA_TYPE(DATA_TYPE, 2)
475 d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100476
477 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100478 DATA_TYPE k0 = d41.s0;
479 DATA_TYPE k1 = d41.s0;
480 DATA_TYPE k2 = d41.s0;
481 DATA_TYPE k3 = d41.s0;
482 DATA_TYPE k4 = d41.s0;
483 DATA_TYPE k5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100484
485 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
486 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
487 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
488 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
489 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
490 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
491
492 out0 += k0;
493 out1 += k1;
494 out2 += k2;
495 out3 += k3;
496 out4 += k4;
497 out5 += k5;
498
499 // Row2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100500 VEC_DATA_TYPE(DATA_TYPE, 4)
501 d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
502 VEC_DATA_TYPE(DATA_TYPE, 2)
503 d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100504
505 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
506 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
507 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
508 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
509 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
510 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
511#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
512
513 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000514#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100515 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000516#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000517 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000518#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100519
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100520 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100521
522 *(dst_addr) = out0;
523 dst_addr += dst_plane_stride;
524 *(dst_addr) = out1;
525 dst_addr += dst_plane_stride;
526 *(dst_addr) = out2;
527 dst_addr += dst_plane_stride;
528 *(dst_addr) = out3;
529 dst_addr += dst_plane_stride;
530 *(dst_addr) = out4;
531 dst_addr += dst_plane_stride;
532 *(dst_addr) = out5;
533 dst_addr += dst_plane_stride;
534
535#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100536 DATA_TYPE out6 = k0;
537 DATA_TYPE out7 = k1;
538 DATA_TYPE out8 = k2;
539 DATA_TYPE out9 = k3;
540 DATA_TYPE out10 = k4;
541 DATA_TYPE out11 = k5;
542 DATA_TYPE out12 = k0;
543 DATA_TYPE out13 = k1;
544 DATA_TYPE out14 = k2;
545 DATA_TYPE out15 = k3;
546 DATA_TYPE out16 = k4;
547 DATA_TYPE out17 = k5;
548 DATA_TYPE out18 = k0;
549 DATA_TYPE out19 = k1;
550 DATA_TYPE out20 = k2;
551 DATA_TYPE out21 = k3;
552 DATA_TYPE out22 = k4;
553 DATA_TYPE out23 = k5;
554 DATA_TYPE out24 = k0;
555 DATA_TYPE out25 = k1;
556 DATA_TYPE out26 = k2;
557 DATA_TYPE out27 = k3;
558 DATA_TYPE out28 = k4;
559 DATA_TYPE out29 = k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100560
561 // Row1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100562 VEC_DATA_TYPE(DATA_TYPE, 4)
563 d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
564 VEC_DATA_TYPE(DATA_TYPE, 2)
565 d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100566
567 // Row3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100568 VEC_DATA_TYPE(DATA_TYPE, 4)
569 d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
570 VEC_DATA_TYPE(DATA_TYPE, 2)
571 d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100572
573 // Compute common parts for the channels between [6, 29]
574 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
575 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100576 DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
577 DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
578 DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
579 DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
580 DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
581 DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
582 DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
583 DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
584 DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
585 DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
586 DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
587 DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100588
589 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
590 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100591 DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
592 DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
593 DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
594 DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
595 DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
596 DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
597 DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
598 DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
599 DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
600 DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
601 DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
602 DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100603
604 out6 += part0 - part1;
605 out12 += part0 + part1;
606 out7 += part2 + part3 + part4 + part5;
607 out8 += part2 - part3 + part4 - part5;
608 out13 += part2 + part3 - part4 - part5;
609 out14 += part2 - part3 - part4 + part5;
610 out9 += part6 + part7 + part8 + part9;
611 out10 += part6 - part7 + part8 - part9;
612 out15 += part6 - part7 - part8 + part9;
613 out16 += part6 + part7 - part8 - part9;
614 out11 += part10 + part11;
615 out17 += part10 - part11;
616
617 out18 += part13 - part12;
618 out24 += part13 + part12;
619 out19 += part14 + part15 + part16 + part17;
620 out20 += part14 - part15 + part16 - part17;
621 out25 += part14 - part15 - part16 + part17;
622 out26 += part14 + part15 - part16 - part17;
623 out21 += part18 + part19 + part20 + part21;
624 out22 += part18 - part19 + part20 - part21;
625 out27 += part18 - part19 - part20 + part21;
626 out28 += part18 + part19 - part20 - part21;
627 out23 += part22 + part23;
628 out29 += part22 - part23;
629
630 *(dst_addr) = out6;
631 dst_addr += dst_plane_stride;
632 *(dst_addr) = out7;
633 dst_addr += dst_plane_stride;
634 *(dst_addr) = out8;
635 dst_addr += dst_plane_stride;
636 *(dst_addr) = out9;
637 dst_addr += dst_plane_stride;
638 *(dst_addr) = out10;
639 dst_addr += dst_plane_stride;
640 *(dst_addr) = out11;
641 dst_addr += dst_plane_stride;
642 *(dst_addr) = out12;
643 dst_addr += dst_plane_stride;
644 *(dst_addr) = out13;
645 dst_addr += dst_plane_stride;
646 *(dst_addr) = out14;
647 dst_addr += dst_plane_stride;
648 *(dst_addr) = out15;
649 dst_addr += dst_plane_stride;
650 *(dst_addr) = out16;
651 dst_addr += dst_plane_stride;
652 *(dst_addr) = out17;
653 dst_addr += dst_plane_stride;
654
655 *(dst_addr) = out18;
656 dst_addr += dst_plane_stride;
657 *(dst_addr) = out19;
658 dst_addr += dst_plane_stride;
659 *(dst_addr) = out20;
660 dst_addr += dst_plane_stride;
661 *(dst_addr) = out21;
662 dst_addr += dst_plane_stride;
663 *(dst_addr) = out22;
664 dst_addr += dst_plane_stride;
665 *(dst_addr) = out23;
666 dst_addr += dst_plane_stride;
667 *(dst_addr) = out24;
668 dst_addr += dst_plane_stride;
669 *(dst_addr) = out25;
670 dst_addr += dst_plane_stride;
671 *(dst_addr) = out26;
672 dst_addr += dst_plane_stride;
673 *(dst_addr) = out27;
674 dst_addr += dst_plane_stride;
675 *(dst_addr) = out28;
676 dst_addr += dst_plane_stride;
677 *(dst_addr) = out29;
678 dst_addr += dst_plane_stride;
679
680 // Row5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100681 VEC_DATA_TYPE(DATA_TYPE, 4)
682 d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
683 VEC_DATA_TYPE(DATA_TYPE, 2)
684 d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100685
686 // Channels [30, 35]
687 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
688 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
689 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
690 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
691 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
692 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
693
694 *(dst_addr) = out0;
695 dst_addr += dst_plane_stride;
696 *(dst_addr) = out1;
697 dst_addr += dst_plane_stride;
698 *(dst_addr) = out2;
699 dst_addr += dst_plane_stride;
700 *(dst_addr) = out3;
701 dst_addr += dst_plane_stride;
702 *(dst_addr) = out4;
703 dst_addr += dst_plane_stride;
704 *(dst_addr) = out5;
705 dst_addr += dst_plane_stride;
706#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
707}
708
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100709/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
710 *
711 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
712 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
713 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
714 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
715 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
716 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
717 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
718 *
719 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
720 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
721 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
722 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
723 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
724 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
725 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
726 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
727 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
728 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
729 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
730 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
731 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
732 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
733 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
734 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
735 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
736 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
737 */
738__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
739 TENSOR3D_DECLARATION(src),
740 TENSOR3D_DECLARATION(dst),
741 uint src_stride_w,
742 uint dst_stride_w)
743{
744 const int x = get_global_id(0);
745 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000746#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100747 const int z = get_global_id(2) % SRC_DEPTH;
748 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000749#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000750 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000751#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100752
753 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000754#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100755 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000756#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000757 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000758#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100759 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
760
761 // Load input tile
762#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
763 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
764#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
765 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
766 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
767 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
768 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
769 *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
770 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
771 *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
772 *((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
773#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
774 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
775 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
776 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
777 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
778 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
779 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
780 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
781 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
782#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
783
784 // Calculate common factors for intermediate tensor
785 VEC_DATA_TYPE(DATA_TYPE, 8)
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100786 tmp0 = in_row0;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100787 VEC_DATA_TYPE(DATA_TYPE, 8)
788 comm_fact0 = 0.0f;
789
790#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena049989a2021-03-22 17:02:26 +0000791 comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100792 tmp0 += -in_row6 + (DATA_TYPE)5.25f * in_row4 - (DATA_TYPE)5.25f * in_row2;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100793
794 VEC_DATA_TYPE(DATA_TYPE, 8)
Giorgio Arena049989a2021-03-22 17:02:26 +0000795 comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100796 VEC_DATA_TYPE(DATA_TYPE, 8)
Giorgio Arena049989a2021-03-22 17:02:26 +0000797 comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100798
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100799 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 + comm_fact1;
800 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 - comm_fact1;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100801
Giorgio Arena049989a2021-03-22 17:02:26 +0000802 comm_fact0 = (DATA_TYPE)2.5f * in_row3;
803 comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.0f * in_row5;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100804
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100805 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact1 + comm_fact2;
806 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 - comm_fact1;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100807
Giorgio Arena049989a2021-03-22 17:02:26 +0000808 comm_fact1 = (DATA_TYPE)2.0f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
809 comm_fact2 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100810
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100811 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact1 + comm_fact2;
812 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact2 - comm_fact1;
813 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * in_row3 - (DATA_TYPE)5.25f * in_row5;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100814#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
815
816 // Calculate output rows (reuse comm_fact0 vector)
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100817 VEC_DATA_TYPE(DATA_TYPE, 8)
818 out0;
819
820 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100821
822#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodice534b8892021-04-01 16:17:16 +0100823 VEC_DATA_TYPE(DATA_TYPE, 8)
824 out1, out2, out3, out4, out5, out6, out7;
825
826 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
827 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
828 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
829 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
830 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
831 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
832 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100833#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
834
835 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000836#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100837 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000838#else /* defined(SRC_DEPTH) */
839 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
840#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100841
842 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
843 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
844 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
845 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
846 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
847 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
848 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
849 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
850
851#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
852 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
853 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
854 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
855 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
856 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
857 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
858 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
859 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
860 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
861 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
862 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
863 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
864 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
865 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
866 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
867 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
868 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
869 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
870 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
871 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
872 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
873 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
874 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
875 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
876 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
877 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
878 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
879 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
880 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
881 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
882 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
883 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
884 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
885 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
886 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
887 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
888 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
889 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
890 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
891 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
892 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
893 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
894 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
895 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
896 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
897 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
898 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
899 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
900 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
901 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
902 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
903 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
904 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
905 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
906 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
907 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
908#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
909}
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100910
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100911#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
912/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
913 *
914 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
915 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
916 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
917 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
918 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100919 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100920 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100921 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100922 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
923 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
924 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
925 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
926 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
927 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
928 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
929 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
930 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
931 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
932 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
933 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
934 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
936 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100937 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
938 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100939 */
940__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
941 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100942 TENSOR3D_DECLARATION(dst),
943 uint src_stride_w,
944 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100945{
946 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
947 src_stride_x,
948 src_step_x,
949 src_stride_y,
950 src_step_y,
951 src_stride_z,
952 src_step_z,
953 src_offset_first_element_in_bytes,
954 dst_ptr,
955 dst_stride_x,
956 dst_step_x,
957 dst_stride_y,
958 dst_step_y,
959 dst_stride_z,
960 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100961 dst_offset_first_element_in_bytes,
962 src_stride_w,
963 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100964}
965
966/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
967 *
968 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
969 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
970 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
971 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
972 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100973 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100974 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100975 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100976 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
977 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
978 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
979 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
980 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
981 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
982 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
983 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
984 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
985 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
986 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
987 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
988 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
989 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
990 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100991 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
992 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100993 */
994__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
995 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100996 TENSOR3D_DECLARATION(dst),
997 uint src_stride_w,
998 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100999{
1000 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
1001 src_stride_x,
1002 src_step_x,
1003 src_stride_y,
1004 src_step_y,
1005 src_stride_z,
1006 src_step_z,
1007 src_offset_first_element_in_bytes,
1008 dst_ptr,
1009 dst_stride_x,
1010 dst_step_x,
1011 dst_stride_y,
1012 dst_step_y,
1013 dst_stride_z,
1014 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001015 dst_offset_first_element_in_bytes,
1016 src_stride_w,
1017 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001018}
1019
1020/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
1021 *
1022 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1023 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1024 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
1025 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
1026 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001027 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001028 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001029 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001030 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1031 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1032 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1033 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1034 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1035 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1036 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1037 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1038 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1039 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1040 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1041 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1042 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1043 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1044 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001045 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1046 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001047 */
1048__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
1049 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001050 TENSOR3D_DECLARATION(dst),
1051 uint src_stride_w,
1052 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001053{
1054 winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
1055 src_stride_x,
1056 src_step_x,
1057 src_stride_y,
1058 src_step_y,
1059 src_stride_z,
1060 src_step_z,
1061 src_offset_first_element_in_bytes,
1062 dst_ptr,
1063 dst_stride_x,
1064 dst_step_x,
1065 dst_stride_y,
1066 dst_step_y,
1067 dst_stride_z,
1068 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001069 dst_offset_first_element_in_bytes,
1070 src_stride_w,
1071 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001072}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001073
1074/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW
1075 *
1076 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1077 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1078 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1079 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1080 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001081 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001082 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001083 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001084 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1085 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1086 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1087 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1088 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1089 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1090 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1091 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1092 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1093 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1094 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1095 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1096 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1097 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1098 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001099 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1100 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001101 */
1102__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
1103 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001104 TENSOR3D_DECLARATION(dst),
1105 uint src_stride_w,
1106 uint dst_stride_w)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001107{
1108 winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
1109 src_stride_x,
1110 src_step_x,
1111 src_stride_y,
1112 src_step_y,
1113 src_stride_z,
1114 src_step_z,
1115 src_offset_first_element_in_bytes,
1116 dst_ptr,
1117 dst_stride_x,
1118 dst_step_x,
1119 dst_stride_y,
1120 dst_step_y,
1121 dst_stride_z,
1122 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001123 dst_offset_first_element_in_bytes,
1124 src_stride_w,
1125 dst_stride_w);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001126}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001127#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1128
1129#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1130/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
1131 *
1132 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1133 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1134 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
1135 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1136 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001137 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001138 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001139 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001140 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1141 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1142 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1143 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1144 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1145 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1146 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1147 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1148 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1149 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1150 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1151 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1152 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1153 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1154 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001155 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1156 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001157 */
1158__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
1159 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001160 TENSOR3D_DECLARATION(dst),
1161 uint src_stride_w,
1162 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001163{
1164 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
1165 src_stride_x,
1166 src_step_x,
1167 src_stride_y,
1168 src_step_y,
1169 src_stride_z,
1170 src_step_z,
1171 src_offset_first_element_in_bytes,
1172 dst_ptr,
1173 dst_stride_x,
1174 dst_step_x,
1175 dst_stride_y,
1176 dst_step_y,
1177 dst_stride_z,
1178 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001179 dst_offset_first_element_in_bytes,
1180 src_stride_w,
1181 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001182}
1183
1184/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is multiple of 2
1185 *
1186 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1187 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1188 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
1189 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1190 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001191 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001192 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001193 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001194 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1195 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1196 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1197 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1198 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1199 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1200 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1201 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1202 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1203 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1204 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1205 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1206 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1207 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1208 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001209 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1210 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001211 */
1212__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
1213 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001214 TENSOR3D_DECLARATION(dst),
1215 uint src_stride_w,
1216 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001217{
1218 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
1219 src_stride_x,
1220 src_step_x,
1221 src_stride_y,
1222 src_step_y,
1223 src_stride_z,
1224 src_step_z,
1225 src_offset_first_element_in_bytes,
1226 dst_ptr,
1227 dst_stride_x,
1228 dst_step_x,
1229 dst_stride_y,
1230 dst_step_y,
1231 dst_stride_z,
1232 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001233 dst_offset_first_element_in_bytes,
1234 src_stride_w,
1235 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001236}
1237
1238/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
1239 *
1240 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1241 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1242 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
1243 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
1244 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001245 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001246 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001247 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001248 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1249 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1250 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1251 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1252 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1253 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1254 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1255 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1256 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1257 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1258 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1259 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1260 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1261 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1262 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001263 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1264 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001265 */
1266__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
1267 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001268 TENSOR3D_DECLARATION(dst),
1269 uint src_stride_w,
1270 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001271{
1272 winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
1273 src_stride_x,
1274 src_step_x,
1275 src_stride_y,
1276 src_step_y,
1277 src_stride_z,
1278 src_step_z,
1279 src_offset_first_element_in_bytes,
1280 dst_ptr,
1281 dst_stride_x,
1282 dst_step_x,
1283 dst_stride_y,
1284 dst_step_y,
1285 dst_stride_z,
1286 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001287 dst_offset_first_element_in_bytes,
1288 src_stride_w,
1289 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001290}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001291
1292/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
1293 *
1294 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1295 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
1296 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
1297 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
1298 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001299 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001300 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001301 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001302 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1303 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1304 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1305 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1306 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1307 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1308 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1309 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1310 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1311 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1312 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1313 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1314 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1315 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1316 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001317 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1318 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001319 */
1320__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
1321 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001322 TENSOR3D_DECLARATION(dst),
1323 uint src_stride_w,
1324 uint dst_stride_w)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001325{
1326 winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
1327 src_stride_x,
1328 src_step_x,
1329 src_stride_y,
1330 src_step_y,
1331 src_stride_z,
1332 src_step_z,
1333 src_offset_first_element_in_bytes,
1334 dst_ptr,
1335 dst_stride_x,
1336 dst_step_x,
1337 dst_stride_y,
1338 dst_step_y,
1339 dst_stride_z,
1340 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001341 dst_offset_first_element_in_bytes,
1342 src_stride_w,
1343 dst_stride_w);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01001344}
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001345#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001346#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)