blob: 94f3772495eaa7f4e2205fb49d9e426031cbcc22 [file] [log] [blame]
/*
 * Copyright (c) 2018-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24#include "helpers.h"
25
/** Conditionally zero six values basename0..basename5 (horizontal variant).
 *
 * For value i the mask is (component i of the y condition) && (z_cond): components
 * .s0-.s3 of y_cond0 cover values 0-3 and components .s0-.s1 of y_cond1 cover values 4-5.
 * Each value is rewritten as select((datatype)0, value, mask), i.e. zero is the fallback
 * operand and the current value is the one chosen when the mask selects it.
 *
 * @note `select` and SELECT_DATA_TYPE are not defined in this file; they are assumed to be
 *       the OpenCL built-in and the helpers.h macro respectively - confirm against helpers.h.
 *
 * @param[in]      datatype Type of the values being masked
 * @param[in, out] basename Prefix of the six variables (basename0 ... basename5)
 * @param[in]      y_cond   Prefix of the two condition vectors (y_cond0, y_cond1)
 * @param[in]      z_cond   Scalar condition shared by all six values
 */
#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(datatype, basename, y_cond, z_cond)                                             \
    ({                                                                                                                  \
        basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond)));     \
        basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond)));     \
        basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond)));     \
        basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond)));     \
        basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s0) && (z_cond)));     \
        basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##1).s1) && (z_cond)));     \
    })
35
/** Conditionally zero six values basename0..basename5 (vertical variant).
 *
 * For value i the mask is (y_cond) && (component i of the z condition): components
 * .s0-.s3 of z_cond0 cover values 0-3 and components .s0-.s1 of z_cond1 cover values 4-5.
 * Each value is rewritten as select((datatype)0, value, mask), i.e. zero is the fallback
 * operand and the current value is the one chosen when the mask selects it.
 *
 * @note `select` and SELECT_DATA_TYPE are not defined in this file; they are assumed to be
 *       the OpenCL built-in and the helpers.h macro respectively - confirm against helpers.h.
 *
 * @param[in]      datatype Type of the values being masked
 * @param[in, out] basename Prefix of the six variables (basename0 ... basename5)
 * @param[in]      y_cond   Scalar condition shared by all six values
 * @param[in]      z_cond   Prefix of the two condition vectors (z_cond0, z_cond1)
 */
#define FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(datatype, basename, y_cond, z_cond)                                             \
    ({                                                                                                                  \
        basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0)));     \
        basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1)));     \
        basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2)));     \
        basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3)));     \
        basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s0)));     \
        basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##1).s1)));     \
    })
45
/** Conditionally zero eight values basename0..basename7 (horizontal variant).
 *
 * For value i the mask is (component .s<i> of y_cond0) && (z_cond); only the single
 * condition vector y_cond0 is read, using components .s0-.s7.
 * Each value is rewritten as select((datatype)0, value, mask), i.e. zero is the fallback
 * operand and the current value is the one chosen when the mask selects it.
 *
 * @note `select` and SELECT_DATA_TYPE are not defined in this file; they are assumed to be
 *       the OpenCL built-in and the helpers.h macro respectively - confirm against helpers.h.
 *
 * @param[in]      datatype Type of the values being masked
 * @param[in, out] basename Prefix of the eight variables (basename0 ... basename7)
 * @param[in]      y_cond   Prefix of the condition vector (y_cond0)
 * @param[in]      z_cond   Scalar condition shared by all eight values
 */
#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(datatype, basename, y_cond, z_cond)                                             \
    ({                                                                                                                  \
        basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s0) && (z_cond)));     \
        basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s1) && (z_cond)));     \
        basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s2) && (z_cond)));     \
        basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s3) && (z_cond)));     \
        basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s4) && (z_cond)));     \
        basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s5) && (z_cond)));     \
        basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s6) && (z_cond)));     \
        basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))(((y_cond##0).s7) && (z_cond)));     \
    })
57
/** Conditionally zero eight values basename0..basename7 (vertical variant).
 *
 * For value i the mask is (y_cond) && (component .s<i> of z_cond0); only the single
 * condition vector z_cond0 is read, using components .s0-.s7.
 * Each value is rewritten as select((datatype)0, value, mask), i.e. zero is the fallback
 * operand and the current value is the one chosen when the mask selects it.
 *
 * @note `select` and SELECT_DATA_TYPE are not defined in this file; they are assumed to be
 *       the OpenCL built-in and the helpers.h macro respectively - confirm against helpers.h.
 *
 * @param[in]      datatype Type of the values being masked
 * @param[in, out] basename Prefix of the eight variables (basename0 ... basename7)
 * @param[in]      y_cond   Scalar condition shared by all eight values
 * @param[in]      z_cond   Prefix of the condition vector (z_cond0)
 */
#define FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(datatype, basename, y_cond, z_cond)                                             \
    ({                                                                                                                  \
        basename##0 = select((datatype)0, basename##0, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s0)));     \
        basename##1 = select((datatype)0, basename##1, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s1)));     \
        basename##2 = select((datatype)0, basename##2, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s2)));     \
        basename##3 = select((datatype)0, basename##3, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s3)));     \
        basename##4 = select((datatype)0, basename##4, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s4)));     \
        basename##5 = select((datatype)0, basename##5, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s5)));     \
        basename##6 = select((datatype)0, basename##6, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s6)));     \
        basename##7 = select((datatype)0, basename##7, (SELECT_DATA_TYPE(datatype))((y_cond) && ((z_cond##0).s7)));     \
    })
69
/** out = B^T * in, where B^T is defined as for the F(4x4,5x5) input transformation.
 *
 * Operates on eight separate variables per operand, addressed by token pasting
 * (in1 ... in6 are read; out0 ... out7 are written).
 *
 * @note out0 and out7 are ACCUMULATED (+=), so the caller must preload them before the
 *       call; out1 ... out6 are plainly overwritten.
 * @note in0 and in7 are not read by this macro.
 *
 * @param[in, out] out        Prefix of the eight result variables (out0 ... out7)
 * @param[in]      in         Prefix of the input variables (in1 ... in6)
 * @param[out]     comm_fact0 Scratch variable holding the even-index partial sum
 * @param[out]     comm_fact1 Scratch variable holding the odd-index partial sum
 * @param[in]      DATA_TYPE  Type the floating-point coefficients are cast to
 */
#define BT_MULTIPLY_4x4_5x5(out, in, comm_fact0, comm_fact1, DATA_TYPE)                              \
    ({                                                                                               \
        comm_fact0 = in##2 + in##6 - (DATA_TYPE)4.25f * in##4;                                       \
        comm_fact1 = in##1 + in##5 - (DATA_TYPE)4.25f * in##3;                                       \
        out##0 += (DATA_TYPE)5.25f * (in##4 - in##2) - in##6;                                        \
        out##7 += (DATA_TYPE)5.25f * (in##3 - in##5) - in##1;                                        \
        out##1 = comm_fact0 + comm_fact1;                                                            \
        out##2 = comm_fact0 - comm_fact1;                                                            \
                                                                                                     \
        comm_fact0 = (DATA_TYPE)0.25f * in##2 - (DATA_TYPE)1.25f * in##4 + in##6;                    \
        comm_fact1 = (DATA_TYPE)0.5f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)2.f * in##5;     \
        out##3     = comm_fact0 + comm_fact1;                                                        \
        out##4     = comm_fact0 - comm_fact1;                                                        \
                                                                                                     \
        comm_fact0 = (DATA_TYPE)4.f * in##2 - (DATA_TYPE)5.f * in##4 + in##6;                        \
        comm_fact1 = (DATA_TYPE)2.f * in##1 - (DATA_TYPE)2.5f * in##3 + (DATA_TYPE)0.5f * in##5;     \
        out##5     = comm_fact0 + comm_fact1;                                                        \
        out##6     = comm_fact0 - comm_fact1;                                                        \
    })
90
/** In-place transform of one 8-component row for the 4x4/5x5 case.
 *
 * All partial sums are first computed from the ORIGINAL components of @p out and stored in
 * @p comm_fact; only afterwards are the components of @p out rewritten, so the statement
 * order must be preserved.
 *
 * @note out.s0 and out.s7 are accumulated (+=) on top of their current value; out.s1 ... out.s6
 *       are overwritten.
 *
 * @param[in, out] out       Object exposing components .s0 ... .s7 (transformed in place)
 * @param[out]     comm_fact Scratch object exposing components .s0 ... .s6
 */
#define OUTPUT_ROW_4x4_5x5(out, comm_fact)                          \
    ({                                                              \
        comm_fact.s2 = 2.5f * out.s3;                               \
        comm_fact.s1 = out.s1 - 4.25f * out.s3 + out.s5;            \
        comm_fact.s0 = out.s2 - 4.25f * out.s4 + out.s6;            \
        comm_fact.s4 = 0.25f * out.s2 - 1.25f * out.s4 + out.s6;    \
        comm_fact.s5 = 4.f * out.s2 - 5.f * out.s4 + out.s6;        \
        comm_fact.s3 = 0.5f * out.s1 + 2.f * out.s5 - comm_fact.s2; \
        comm_fact.s6 = 2.f * out.s1 + 0.5f * out.s5 - comm_fact.s2; \
                                                                    \
        out.s0 += 5.25f * (out.s4 - out.s2) - out.s6;               \
        out.s7 += 5.25f * (out.s3 - out.s5) - out.s1;               \
        out.s1 = comm_fact.s0 + comm_fact.s1;                       \
        out.s2 = comm_fact.s0 - comm_fact.s1;                       \
        out.s3 = comm_fact.s3 + comm_fact.s4;                       \
        out.s4 = comm_fact.s4 - comm_fact.s3;                       \
        out.s5 = comm_fact.s5 + comm_fact.s6;                       \
        out.s6 = comm_fact.s5 - comm_fact.s6;                       \
    })
110
/** Transform one 8-component row for the 2x2/7x7 case.
 *
 * Reads components .s0 ... .s7 of @p tmp, stores six partial sums in @p comm_fact and writes
 * all eight components of @p out. Unlike OUTPUT_ROW_4x4_5x5 the input (@p tmp) and the
 * output (@p out) are separate operands and nothing is accumulated.
 *
 * @note The explicit `0.0f * tmp.s2` term in out.s7 is kept as written: dropping it would
 *       change the result when tmp.s2 is NaN or infinite.
 *
 * @param[out] out       Object exposing components .s0 ... .s7 (fully overwritten)
 * @param[in]  tmp       Object exposing components .s0 ... .s7 (read only)
 * @param[out] comm_fact Scratch object exposing components .s0 ... .s5
 */
#define OUTPUT_ROW_2x2_7x7(out, tmp, comm_fact)                                                    \
    ({                                                                                             \
        comm_fact.s0 = 36.0f * tmp.s2 - 13.0f * tmp.s4 + tmp.s6;                                   \
        comm_fact.s1 = 36.0f * tmp.s1 - 13.0f * tmp.s3 + 1.0f * tmp.s5;                            \
        comm_fact.s2 = 9.0f * tmp.s2 - 10.0f * tmp.s4 + tmp.s6;                                    \
        comm_fact.s3 = 18.0f * tmp.s1 - 20.0f * tmp.s3 + 2.0f * tmp.s5;                            \
        comm_fact.s4 = 4.0f * tmp.s2 - 5.0f * tmp.s4 + tmp.s6;                                     \
        comm_fact.s5 = 12.0f * tmp.s1 - 15.0f * tmp.s3 + 3.0f * tmp.s5;                            \
        out.s0       = -36.0f * tmp.s0 + 49.0f * tmp.s2 + -14.0f * tmp.s4 + tmp.s6;                \
        out.s1       = comm_fact.s0 - comm_fact.s1;                                                \
        out.s2       = comm_fact.s0 + comm_fact.s1;                                                \
        out.s3       = comm_fact.s2 - comm_fact.s3;                                                \
        out.s4       = comm_fact.s2 + comm_fact.s3;                                                \
        out.s5       = comm_fact.s4 - comm_fact.s5;                                                \
        out.s6       = comm_fact.s4 + comm_fact.s5;                                                \
        out.s7       = -36.0f * tmp.s1 + 0.0f * tmp.s2 + 49.0f * tmp.s3 - 14.0f * tmp.s5 + tmp.s7; \
    })
128
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100129#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
 *
 * Each work-item loads one input patch (4x4, or 4 elements for the 1D variants) for tile (x, y) of one
 * channel, applies the same 4-tap combination along both directions and writes the results along the
 * Z dimension of the destination (16 values for the 2D case, 4 for the 1D cases).
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note If -DSRC_DEPTH is passed, get_global_id(2) is split into a channel index (id % SRC_DEPTH) and a batch index (id / SRC_DEPTH);
 *       otherwise it is used directly as the channel index and the W strides are ignored.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    const int z = get_global_id(2) % SRC_DEPTH;
    const int b = get_global_id(2) / SRC_DEPTH;
#else  /* defined(SRC_DEPTH) */
    const int z = get_global_id(2);
#endif /* defined(SRC_DEPTH) */

    // Compute input address
#if defined(SRC_DEPTH)
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */

    // Step back by the left/top padding so the patch starts PAD_LEFT elements / PAD_TOP rows before the tile origin
    // NOTE(review): presumably the source tensor is allocated with enough border for this read - confirm with the host code
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1D horizontal: four contiguous elements of a single row
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1D vertical: one element from each of four consecutive rows
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // 2D: four rows of four elements
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp0 = in_row0;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row combination for the first output row (row0 - row2); the 1D variants use the loaded vector unchanged
    tmp0 -= in_row2;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Column combination: (s0 - s2, s1 + s2, s2 - s1, s1 - s3)
    DATA_TYPE out00 = tmp0.s0 - tmp0.s2;
    DATA_TYPE out01 = tmp0.s1 + tmp0.s2;
    DATA_TYPE out02 = tmp0.s2 - tmp0.s1;
    DATA_TYPE out03 = tmp0.s1 - tmp0.s3;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining row combinations, mirroring the column pattern above
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp1 = in_row1 + in_row2;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp2 = in_row2 - in_row1;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp3 = in_row1 - in_row3;

    DATA_TYPE out10 = tmp1.s0 - tmp1.s2;
    DATA_TYPE out11 = tmp1.s1 + tmp1.s2;
    DATA_TYPE out12 = tmp1.s2 - tmp1.s1;
    DATA_TYPE out13 = tmp1.s1 - tmp1.s3;

    DATA_TYPE out20 = tmp2.s0 - tmp2.s2;
    DATA_TYPE out21 = tmp2.s1 + tmp2.s2;
    DATA_TYPE out22 = tmp2.s2 - tmp2.s1;
    DATA_TYPE out23 = tmp2.s1 - tmp2.s3;

    DATA_TYPE out30 = tmp3.s0 - tmp3.s2;
    DATA_TYPE out31 = tmp3.s1 + tmp3.s2;
    DATA_TYPE out32 = tmp3.s2 - tmp3.s1;
    DATA_TYPE out33 = tmp3.s1 - tmp3.s3;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Destination: channel along X, linearised tile index (x + y * NUM_TILES_X) along Y, transformed element along Z
#if defined(SRC_DEPTH)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */

    *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out00;
    *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out01;
    *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out02;
    *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out03;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z))  = out10;
    *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z))  = out11;
    *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z))  = out12;
    *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z))  = out13;
    *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z))  = out20;
    *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z))  = out21;
    *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out22;
    *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out23;
    *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out30;
    *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out31;
    *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out32;
    *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out33;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
265
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
 *
 * Each work-item processes two consecutive channels (z and z + 1) of tile (x, y): the patch of the second
 * channel is read at src_addr + src_stride_z and the two results are stored together with vstore2.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note If -DSRC_DEPTH is passed, (get_global_id(2) * 2) is split into a channel index (% SRC_DEPTH) and a batch index (/ SRC_DEPTH);
 *       otherwise it is used directly as the channel index and the W strides are ignored.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    const int x = get_global_id(0);
    const int y = get_global_id(1);
#if defined(SRC_DEPTH)
    const int z = (get_global_id(2) * 2) % SRC_DEPTH;
    const int b = (get_global_id(2) * 2) / SRC_DEPTH;
#else  /* defined(SRC_DEPTH) */
    const int z = get_global_id(2) * 2;
#endif /* defined(SRC_DEPTH) */

    // Compute input address
#if defined(SRC_DEPTH)
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
#endif /* defined(SRC_DEPTH) */

    // Step back by the left/top padding so the patch starts PAD_LEFT elements / PAD_TOP rows before the tile origin
    // NOTE(review): presumably the source tensor is allocated with enough border for this read - confirm with the host code
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);

    // First channel (z): in_row0 .. in_row3
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row0 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row1 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row2 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row3 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Second channel (z + 1): in_row4 .. in_row7
    src_addr += src_stride_z;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
                                            *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row4 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row5 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row6 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
    VEC_DATA_TYPE(DATA_TYPE, 4)
    in_row7 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp0 = in_row0;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp4 = in_row4;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row combination for the first output row (row0 - row2) of each channel
    tmp0 -= in_row2;
    tmp4 -= in_row6;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Column combination; component 0 of every outXY belongs to channel z, component 1 to channel z + 1
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out00 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out02 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out03 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining row combinations for channel z
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp1 = in_row1 + in_row2;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp2 = in_row2 - in_row1;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp3 = in_row1 - in_row3;

    // Remaining row combinations for channel z + 1
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp5 = in_row5 + in_row6;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp6 = in_row6 - in_row5;
    VEC_DATA_TYPE(DATA_TYPE, 4)
    tmp7 = in_row5 - in_row7;

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out10 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out11 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out12 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out13 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out20 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out21 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out22 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out23 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);

    VEC_DATA_TYPE(DATA_TYPE, 2)
    out30 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out31 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out32 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
    VEC_DATA_TYPE(DATA_TYPE, 2)
    out33 = (VEC_DATA_TYPE(DATA_TYPE, 2))(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Destination: channel along X, linearised tile index (x + y * NUM_TILES_X) along Y, transformed element along Z
#if defined(SRC_DEPTH)
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
#else  /* defined(SRC_DEPTH) */
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
#endif /* defined(SRC_DEPTH) */

    // Each vstore2 writes the pair (channel z, channel z + 1) contiguously
    vstore2(out00, 0, (__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z));
    vstore2(out01, 0, (__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z));
    vstore2(out02, 0, (__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z));
    vstore2(out03, 0, (__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z));

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    vstore2(out10, 0, (__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z));
    vstore2(out11, 0, (__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z));
    vstore2(out12, 0, (__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z));
    vstore2(out13, 0, (__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z));
    vstore2(out20, 0, (__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z));
    vstore2(out21, 0, (__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z));
    vstore2(out22, 0, (__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z));
    vstore2(out23, 0, (__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z));
    vstore2(out30, 0, (__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z));
    vstore2(out31, 0, (__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z));
    vstore2(out32, 0, (__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z));
    vstore2(out33, 0, (__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
447
Gian Marco Iodice876be2a2018-07-03 12:22:09 +0100448/** This OpenCL kernel computes the input transform when the output tile is 4x4/4x1 or 1x4, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100449 *
450 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
451 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
452 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
453 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
454 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
455 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100456 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100457 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100458 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100459 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
460 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
461 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
462 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
463 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
464 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
465 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
466 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
467 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
468 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
469 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
470 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
471 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
472 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
473 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100474 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
475 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100476 */
477__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
478 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100479 TENSOR3D_DECLARATION(dst),
480 uint src_stride_w,
481 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100482{
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100483 const int x = get_global_id(0);
484 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000485#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100486 const int z = get_global_id(2) % SRC_DEPTH;
487 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000488#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000489 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000490#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100491
492 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000493#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100494 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000495#else /* defined(SRC_DEPTH) */
496 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
497#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100498
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100499 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100500
501#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
502 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100503 VEC_DATA_TYPE(DATA_TYPE, 4)
504 d00 = (VEC_DATA_TYPE(DATA_TYPE, 4))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
505 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
506 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
507 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)));
508 VEC_DATA_TYPE(DATA_TYPE, 2)
509 d01 = (VEC_DATA_TYPE(DATA_TYPE, 2))(*((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
510 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100511#else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
512 // Row0
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100513 VEC_DATA_TYPE(DATA_TYPE, 4)
514 d00 = vload4(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
515 VEC_DATA_TYPE(DATA_TYPE, 2)
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000516 d01 = vload2(2, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100517#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
518
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100519 DATA_TYPE out0 = 0.0f;
520 DATA_TYPE out1 = 0.0f;
521 DATA_TYPE out2 = 0.0f;
522 DATA_TYPE out3 = 0.0f;
523 DATA_TYPE out4 = 0.0f;
524 DATA_TYPE out5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100525
526 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
527 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
528 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
529 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
530 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
531 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
532 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;
533
534#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
535 // Row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100536 VEC_DATA_TYPE(DATA_TYPE, 4)
537 d40 = vload4(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
538 VEC_DATA_TYPE(DATA_TYPE, 2)
539 d41 = vload2(2, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100540
541 // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100542 DATA_TYPE k0 = d41.s0;
543 DATA_TYPE k1 = d41.s0;
544 DATA_TYPE k2 = d41.s0;
545 DATA_TYPE k3 = d41.s0;
546 DATA_TYPE k4 = d41.s0;
547 DATA_TYPE k5 = 0.0f;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100548
549 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
550 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
551 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
552 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
553 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
554 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
555
556 out0 += k0;
557 out1 += k1;
558 out2 += k2;
559 out3 += k3;
560 out4 += k4;
561 out5 += k5;
562
563 // Row2
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100564 VEC_DATA_TYPE(DATA_TYPE, 4)
565 d20 = vload4(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
566 VEC_DATA_TYPE(DATA_TYPE, 2)
567 d21 = vload2(2, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100568
569 out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
570 out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
571 out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
572 out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
573 out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
574 out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
575#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
576
577 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000578#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100579 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000580#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000581 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000582#endif /* defined(SRC_DEPTH) */
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100583
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100584 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100585
586 *(dst_addr) = out0;
587 dst_addr += dst_plane_stride;
588 *(dst_addr) = out1;
589 dst_addr += dst_plane_stride;
590 *(dst_addr) = out2;
591 dst_addr += dst_plane_stride;
592 *(dst_addr) = out3;
593 dst_addr += dst_plane_stride;
594 *(dst_addr) = out4;
595 dst_addr += dst_plane_stride;
596 *(dst_addr) = out5;
597 dst_addr += dst_plane_stride;
598
599#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100600 DATA_TYPE out6 = k0;
601 DATA_TYPE out7 = k1;
602 DATA_TYPE out8 = k2;
603 DATA_TYPE out9 = k3;
604 DATA_TYPE out10 = k4;
605 DATA_TYPE out11 = k5;
606 DATA_TYPE out12 = k0;
607 DATA_TYPE out13 = k1;
608 DATA_TYPE out14 = k2;
609 DATA_TYPE out15 = k3;
610 DATA_TYPE out16 = k4;
611 DATA_TYPE out17 = k5;
612 DATA_TYPE out18 = k0;
613 DATA_TYPE out19 = k1;
614 DATA_TYPE out20 = k2;
615 DATA_TYPE out21 = k3;
616 DATA_TYPE out22 = k4;
617 DATA_TYPE out23 = k5;
618 DATA_TYPE out24 = k0;
619 DATA_TYPE out25 = k1;
620 DATA_TYPE out26 = k2;
621 DATA_TYPE out27 = k3;
622 DATA_TYPE out28 = k4;
623 DATA_TYPE out29 = k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100624
625 // Row1
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100626 VEC_DATA_TYPE(DATA_TYPE, 4)
627 d10 = vload4(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
628 VEC_DATA_TYPE(DATA_TYPE, 2)
629 d11 = vload2(2, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100630
631 // Row3
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100632 VEC_DATA_TYPE(DATA_TYPE, 4)
633 d30 = vload4(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
634 VEC_DATA_TYPE(DATA_TYPE, 2)
635 d31 = vload2(2, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100636
637 // Compute common parts for the channels between [6, 29]
638 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
639 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100640 DATA_TYPE part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
641 DATA_TYPE part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
642 DATA_TYPE part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
643 DATA_TYPE part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
644 DATA_TYPE part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
645 DATA_TYPE part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
646 DATA_TYPE part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
647 DATA_TYPE part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
648 DATA_TYPE part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
649 DATA_TYPE part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
650 DATA_TYPE part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
651 DATA_TYPE part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100652
653 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
654 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100655 DATA_TYPE part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
656 DATA_TYPE part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
657 DATA_TYPE part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
658 DATA_TYPE part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
659 DATA_TYPE part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
660 DATA_TYPE part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
661 DATA_TYPE part18 = part6 * 0.25f; // d20.s2 - d21.s0
662 DATA_TYPE part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
663 DATA_TYPE part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
664 DATA_TYPE part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
665 DATA_TYPE part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
666 DATA_TYPE part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100667
668 out6 += part0 - part1;
669 out12 += part0 + part1;
670 out7 += part2 + part3 + part4 + part5;
671 out8 += part2 - part3 + part4 - part5;
672 out13 += part2 + part3 - part4 - part5;
673 out14 += part2 - part3 - part4 + part5;
674 out9 += part6 + part7 + part8 + part9;
675 out10 += part6 - part7 + part8 - part9;
676 out15 += part6 - part7 - part8 + part9;
677 out16 += part6 + part7 - part8 - part9;
678 out11 += part10 + part11;
679 out17 += part10 - part11;
680
681 out18 += part13 - part12;
682 out24 += part13 + part12;
683 out19 += part14 + part15 + part16 + part17;
684 out20 += part14 - part15 + part16 - part17;
685 out25 += part14 - part15 - part16 + part17;
686 out26 += part14 + part15 - part16 - part17;
687 out21 += part18 + part19 + part20 + part21;
688 out22 += part18 - part19 + part20 - part21;
689 out27 += part18 - part19 - part20 + part21;
690 out28 += part18 + part19 - part20 - part21;
691 out23 += part22 + part23;
692 out29 += part22 - part23;
693
694 *(dst_addr) = out6;
695 dst_addr += dst_plane_stride;
696 *(dst_addr) = out7;
697 dst_addr += dst_plane_stride;
698 *(dst_addr) = out8;
699 dst_addr += dst_plane_stride;
700 *(dst_addr) = out9;
701 dst_addr += dst_plane_stride;
702 *(dst_addr) = out10;
703 dst_addr += dst_plane_stride;
704 *(dst_addr) = out11;
705 dst_addr += dst_plane_stride;
706 *(dst_addr) = out12;
707 dst_addr += dst_plane_stride;
708 *(dst_addr) = out13;
709 dst_addr += dst_plane_stride;
710 *(dst_addr) = out14;
711 dst_addr += dst_plane_stride;
712 *(dst_addr) = out15;
713 dst_addr += dst_plane_stride;
714 *(dst_addr) = out16;
715 dst_addr += dst_plane_stride;
716 *(dst_addr) = out17;
717 dst_addr += dst_plane_stride;
718
719 *(dst_addr) = out18;
720 dst_addr += dst_plane_stride;
721 *(dst_addr) = out19;
722 dst_addr += dst_plane_stride;
723 *(dst_addr) = out20;
724 dst_addr += dst_plane_stride;
725 *(dst_addr) = out21;
726 dst_addr += dst_plane_stride;
727 *(dst_addr) = out22;
728 dst_addr += dst_plane_stride;
729 *(dst_addr) = out23;
730 dst_addr += dst_plane_stride;
731 *(dst_addr) = out24;
732 dst_addr += dst_plane_stride;
733 *(dst_addr) = out25;
734 dst_addr += dst_plane_stride;
735 *(dst_addr) = out26;
736 dst_addr += dst_plane_stride;
737 *(dst_addr) = out27;
738 dst_addr += dst_plane_stride;
739 *(dst_addr) = out28;
740 dst_addr += dst_plane_stride;
741 *(dst_addr) = out29;
742 dst_addr += dst_plane_stride;
743
744 // Row5
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100745 VEC_DATA_TYPE(DATA_TYPE, 4)
746 d50 = vload4(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
747 VEC_DATA_TYPE(DATA_TYPE, 2)
748 d51 = vload2(2, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100749
750 // Channels [30, 35]
751 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
752 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
753 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
754 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
755 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
756 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
757
758 *(dst_addr) = out0;
759 dst_addr += dst_plane_stride;
760 *(dst_addr) = out1;
761 dst_addr += dst_plane_stride;
762 *(dst_addr) = out2;
763 dst_addr += dst_plane_stride;
764 *(dst_addr) = out3;
765 dst_addr += dst_plane_stride;
766 *(dst_addr) = out4;
767 dst_addr += dst_plane_stride;
768 *(dst_addr) = out5;
769 dst_addr += dst_plane_stride;
770#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
771}
772
/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NCHW
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
802__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
803 TENSOR3D_DECLARATION(src),
804 TENSOR3D_DECLARATION(dst),
805 uint src_stride_w,
806 uint dst_stride_w)
807{
808 const int x = get_global_id(0);
809 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000810#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100811 const int z = get_global_id(2) % SRC_DEPTH;
812 const int b = get_global_id(2) / SRC_DEPTH;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000813#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000814 const int z = get_global_id(2);
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000815#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100816
817 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000818#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100819 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z + b * src_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000820#else /* defined(SRC_DEPTH) */
Michele Di Giorgiof955d512019-02-27 14:26:51 +0000821 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(DATA_TYPE) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000822#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100823 src_addr = src_addr - ((int)PAD_LEFT * sizeof(DATA_TYPE)) - ((int)PAD_TOP * src_stride_y);
824
825 // Load input tile
826#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
827 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr));
828#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
829 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = (VEC_DATA_TYPE(DATA_TYPE, 8))(*((__global DATA_TYPE *)(src_addr + 0 * src_stride_y)),
830 *((__global DATA_TYPE *)(src_addr + 1 * src_stride_y)),
831 *((__global DATA_TYPE *)(src_addr + 2 * src_stride_y)),
832 *((__global DATA_TYPE *)(src_addr + 3 * src_stride_y)),
833 *((__global DATA_TYPE *)(src_addr + 4 * src_stride_y)),
834 *((__global DATA_TYPE *)(src_addr + 5 * src_stride_y)),
835 *((__global DATA_TYPE *)(src_addr + 6 * src_stride_y)),
836 *((__global DATA_TYPE *)(src_addr + 7 * src_stride_y)));
837#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
838 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row0 = vload8(0, (__global DATA_TYPE *)(src_addr + 0 * src_stride_y));
839 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row1 = vload8(0, (__global DATA_TYPE *)(src_addr + 1 * src_stride_y));
840 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row2 = vload8(0, (__global DATA_TYPE *)(src_addr + 2 * src_stride_y));
841 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row3 = vload8(0, (__global DATA_TYPE *)(src_addr + 3 * src_stride_y));
842 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row4 = vload8(0, (__global DATA_TYPE *)(src_addr + 4 * src_stride_y));
843 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row5 = vload8(0, (__global DATA_TYPE *)(src_addr + 5 * src_stride_y));
844 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row6 = vload8(0, (__global DATA_TYPE *)(src_addr + 6 * src_stride_y));
845 const VEC_DATA_TYPE(DATA_TYPE, 8) in_row7 = vload8(0, (__global DATA_TYPE *)(src_addr + 7 * src_stride_y));
846#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
847
848 // Calculate common factors for intermediate tensor
849 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000850 out0 = in_row0;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100851 VEC_DATA_TYPE(DATA_TYPE, 8)
852 comm_fact0 = 0.0f;
853
854#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000855 VEC_DATA_TYPE(DATA_TYPE, 8)
856 out1, out2, out3, out4, out5, out6, out7;
Giorgio Arena049989a2021-03-22 17:02:26 +0000857 comm_fact0 += in_row2 + in_row6 - (DATA_TYPE)4.25f * in_row4;
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000858 out0 += -in_row6 + (DATA_TYPE)5.25f * (in_row4 - in_row2);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100859
860 VEC_DATA_TYPE(DATA_TYPE, 8)
Giorgio Arena049989a2021-03-22 17:02:26 +0000861 comm_fact1 = in_row1 + in_row5 - (DATA_TYPE)4.25f * in_row3;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100862 VEC_DATA_TYPE(DATA_TYPE, 8)
Giorgio Arena049989a2021-03-22 17:02:26 +0000863 comm_fact2 = (DATA_TYPE)0.25f * in_row2 - (DATA_TYPE)1.25f * in_row4 + in_row6;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100864
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000865 out1 = comm_fact0 + comm_fact1;
866 out2 = comm_fact0 - comm_fact1;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100867
Giorgio Arena049989a2021-03-22 17:02:26 +0000868 comm_fact0 = (DATA_TYPE)2.5f * in_row3;
869 comm_fact1 = (DATA_TYPE)0.5f * in_row1 - comm_fact0 + (DATA_TYPE)2.0f * in_row5;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100870
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000871 out3 = comm_fact1 + comm_fact2;
872 out4 = comm_fact2 - comm_fact1;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100873
Giorgio Arena049989a2021-03-22 17:02:26 +0000874 comm_fact1 = (DATA_TYPE)2.0f * in_row1 - comm_fact0 + (DATA_TYPE)0.5f * in_row5;
875 comm_fact2 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100876
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000877 out5 = comm_fact1 + comm_fact2;
878 out6 = comm_fact2 - comm_fact1;
879 out7 = in_row7 - in_row1 + (DATA_TYPE)5.25f * (in_row3 - in_row5);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100880#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
881
882 // Calculate output rows (reuse comm_fact0 vector)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000883 OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100884
885#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +0000886 OUTPUT_ROW_4x4_5x5(out1, comm_fact0);
887 OUTPUT_ROW_4x4_5x5(out2, comm_fact0);
888 OUTPUT_ROW_4x4_5x5(out3, comm_fact0);
889 OUTPUT_ROW_4x4_5x5(out4, comm_fact0);
890 OUTPUT_ROW_4x4_5x5(out5, comm_fact0);
891 OUTPUT_ROW_4x4_5x5(out6, comm_fact0);
892 OUTPUT_ROW_4x4_5x5(out7, comm_fact0);
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100893#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
894
895 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000896#if defined(SRC_DEPTH)
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100897 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000898#else /* defined(SRC_DEPTH) */
899 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(DATA_TYPE) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
900#endif /* defined(SRC_DEPTH) */
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100901
902 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
903 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
904 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
905 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
906 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
907 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
908 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
909 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
910
911#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
912 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
913 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
914 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
915 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
916 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
917 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
918 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
919 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
920 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
921 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
922 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
923 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
924 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
925 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
926 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
927 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
928 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
929 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
930 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
931 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
932 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
933 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
934 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
935 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
936 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
937 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
938 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
939 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
940 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
941 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
942 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
943 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
944 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
945 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
946 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
947 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
948 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
949 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
950 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
951 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
952 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
953 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
954 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
955 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
956 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
957 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
958 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
959 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
960 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
961 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
962 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
963 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
964 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
965 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
966 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
967 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
968#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
969}
Georgios Pinitasc55beee2018-10-23 15:23:23 +0100970
Georgios Pinitasffb57a02018-10-29 18:01:52 +0000971#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100972/** This OpenCL kernel computes the input transform when the output tile is 4x4, 4x1 or 1x4, the filter size 3x3, 3x1 or 1x3 and the data layout is NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100973 *
974 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
975 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
976 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
977 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
Giorgio Arena149fdf32018-07-04 17:03:33 +0100978 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
979 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
980 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
981 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100982 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100983 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +0100984 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +0100985 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
986 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
987 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
988 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
989 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
990 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
991 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
992 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
993 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
994 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
995 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
996 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
997 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
998 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
999 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001000 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1001 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001002 */
1003__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
1004 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001005 TENSOR3D_DECLARATION(dst),
1006 uint src_stride_w,
1007 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001008{
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001009 // Index channel
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001010 const int x = get_global_id(0);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001011 // Index width
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001012 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001013#if defined(NUM_TILES_Y)
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001014 // Index height
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001015 const int z = get_global_id(2) % NUM_TILES_Y;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001016 // Index batch size
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001017 const int b = get_global_id(2) / NUM_TILES_Y;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001018#else // defined(NUM_TILES_Y)
1019 // Index height
Giorgio Arena2d1a8352020-10-26 15:04:08 +00001020 const int z = get_global_id(2);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001021#endif // defined(NUM_TILES_Y)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001022
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001023#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001024 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001025#else // defined(NUM_TILES_Y)
1026 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
1027#endif // defined(NUM_TILES_Y)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001028
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001029 // Origin coordinates for the width (y) and height (z) in the input tensor
Giorgio Arena149fdf32018-07-04 17:03:33 +01001030 int4 y_coord0 = (int4)(y * OUTPUT_TILE_W) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
1031 int2 y_coord1 = (int2)(y * OUTPUT_TILE_W) + (int2)(4, 5) - (int2)PAD_LEFT;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001032 int4 z_coord0 = (int4)(z * OUTPUT_TILE_H) + (int4)(0, 1, 2, 3) - (int4)PAD_TOP;
1033 int2 z_coord1 = (int2)(z * OUTPUT_TILE_H) + (int2)(4, 5) - (int2)PAD_TOP;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001034
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001035 // Coordinates to use to avoid out-of-bound reads
1036 int4 y_coord_valid0 = clamp(y_coord0, (int4)0, (int4)((int)SRC_DIM_1 - 1));
1037 int2 y_coord_valid1 = clamp(y_coord1, (int2)0, (int2)((int)SRC_DIM_1 - 1));
1038 int4 z_coord_valid0 = clamp(z_coord0, (int4)0, (int4)((int)SRC_DIM_2 - 1));
1039 int2 z_coord_valid1 = clamp(z_coord1, (int2)0, (int2)((int)SRC_DIM_2 - 1));
1040
1041 // Boundary conditions
1042 int4 y_cond0 = y_coord_valid0 == y_coord0;
1043 int2 y_cond1 = y_coord_valid1 == y_coord1;
1044 int4 z_cond0 = z_coord_valid0 == z_coord0;
1045 int2 z_cond1 = z_coord_valid1 == z_coord1;
Giorgio Arena149fdf32018-07-04 17:03:33 +01001046
Giorgio Arena149fdf32018-07-04 17:03:33 +01001047#if !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001048 DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1049 DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1050 DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1051 DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1052 DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1053 DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001054
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001055 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d0, y_cond, z_cond0.s0);
Giorgio Arena149fdf32018-07-04 17:03:33 +01001056#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena2d1a8352020-10-26 15:04:08 +00001057 DATA_TYPE d00 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1058 DATA_TYPE d01 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1059 DATA_TYPE d02 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1060 DATA_TYPE d03 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1061 DATA_TYPE d04 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1062 DATA_TYPE d05 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001063
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001064 FILL_ZERO_OUT_OF_BOUND_6_NHWC_V(DATA_TYPE, d0, y_cond0.s0, z_cond);
Giorgio Arena149fdf32018-07-04 17:03:33 +01001065#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1066
Giorgio Arena149fdf32018-07-04 17:03:33 +01001067#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena049989a2021-03-22 17:02:26 +00001068 DATA_TYPE d10 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1069 DATA_TYPE d11 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1070 DATA_TYPE d12 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1071 DATA_TYPE d13 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1072 DATA_TYPE d14 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1073 DATA_TYPE d15 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1074
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001075 DATA_TYPE d20 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1076 DATA_TYPE d21 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1077 DATA_TYPE d22 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1078 DATA_TYPE d23 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1079 DATA_TYPE d24 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1080 DATA_TYPE d25 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001081
Giorgio Arena049989a2021-03-22 17:02:26 +00001082 DATA_TYPE d30 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1083 DATA_TYPE d31 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1084 DATA_TYPE d32 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1085 DATA_TYPE d33 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1086 DATA_TYPE d34 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1087 DATA_TYPE d35 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1088
1089 DATA_TYPE d40 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1090 DATA_TYPE d41 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1091 DATA_TYPE d42 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1092 DATA_TYPE d43 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1093 DATA_TYPE d44 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1094 DATA_TYPE d45 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s0 * src_stride_z);
1095
1096 DATA_TYPE d50 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1097 DATA_TYPE d51 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1098 DATA_TYPE d52 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1099 DATA_TYPE d53 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1100 DATA_TYPE d54 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s0 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1101 DATA_TYPE d55 = *(__global DATA_TYPE *)(src_addr + y_coord_valid1.s1 * (int)src_stride_y + z_coord_valid1.s1 * src_stride_z);
1102
1103 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d1, y_cond, z_cond0.s1);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001104 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d2, y_cond, z_cond0.s2);
Giorgio Arena049989a2021-03-22 17:02:26 +00001105 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d3, y_cond, z_cond0.s3);
1106 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d4, y_cond, z_cond1.s0);
1107 FILL_ZERO_OUT_OF_BOUND_6_NHWC_H(DATA_TYPE, d5, y_cond, z_cond1.s1);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001108
Giorgio Arena049989a2021-03-22 17:02:26 +00001109 DATA_TYPE k0, k1, k2, k3, k4, k5;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001110
Giorgio Arena049989a2021-03-22 17:02:26 +00001111 DATA_TYPE part00, part01, part02, part03, part04, part05;
1112 DATA_TYPE part10, part11, part12, part13, part14, part15;
1113 DATA_TYPE part20, part21, part22, part23, part24, part25;
1114 DATA_TYPE part30, part31, part32, part33, part34, part35;
1115 DATA_TYPE part40, part41, part42, part43, part44, part45;
1116 DATA_TYPE part50, part51, part52, part53, part54, part55;
1117
1118#define COMMON_OPS_0(i) \
1119 k0 = d2##i - 4.f * d0##i; \
1120 k1 = d3##i - 4.f * d1##i; \
1121 k2 = d4##i - 4.f * d2##i; \
1122 k3 = d5##i - 4.f * d3##i; \
1123 k4 = d3##i - d1##i; \
1124 k4 = k4 + k4; \
1125 k5 = d4##i - d2##i; \
1126 part0##i = k2 - k0; \
1127 part1##i = k2 + k1; \
1128 part2##i = k2 - k1; \
1129 part3##i = k5 + k4; \
1130 part4##i = k5 - k4; \
1131 part5##i = k3 - k1;
1132
1133#define COMMON_OPS_1(i) \
1134 k0 = part##i##2 - 4.f * part##i##0; \
1135 k1 = part##i##3 - 4.f * part##i##1; \
1136 k2 = part##i##4 - 4.f * part##i##2; \
1137 k3 = part##i##5 - 4.f * part##i##3; \
1138 k4 = part##i##3 - part##i##1; \
1139 k4 = k4 + k4; \
1140 k5 = part##i##4 - part##i##2; \
1141 DATA_TYPE out##i##0 = k2 - k0; \
1142 DATA_TYPE out##i##1 = k2 + k1; \
1143 DATA_TYPE out##i##2 = k2 - k1; \
1144 DATA_TYPE out##i##3 = k5 + k4; \
1145 DATA_TYPE out##i##4 = k5 - k4; \
1146 DATA_TYPE out##i##5 = k3 - k1;
1147
1148 COMMON_OPS_0(0);
1149 COMMON_OPS_0(1);
1150 COMMON_OPS_0(2);
1151 COMMON_OPS_0(3);
1152 COMMON_OPS_0(4);
1153 COMMON_OPS_0(5);
1154
1155 COMMON_OPS_1(0);
1156 COMMON_OPS_1(1);
1157 COMMON_OPS_1(2);
1158 COMMON_OPS_1(3);
1159 COMMON_OPS_1(4);
1160 COMMON_OPS_1(5);
1161
1162#undef COMMON_OPS_0
1163#undef COMMON_OPS_1
1164
1165#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1166
1167 DATA_TYPE k0, k1, k2, k3, k4, k5;
1168 DATA_TYPE part0, part1, part2, part3, part4, part5;
1169
1170 part0 = 4.f * d00;
1171 part1 = 4.f * d01;
1172 part2 = 4.f * d02;
1173 part3 = 4.f * d03;
1174 part4 = 4.f * d04;
1175 part5 = 4.f * d05;
1176
1177 k0 = part2 - 4.f * part0;
1178 k1 = part3 - 4.f * part1;
1179 k2 = part4 - 4.f * part2;
1180 k3 = part5 - 4.f * part3;
1181 k4 = part3 - part1;
1182 k4 = k4 + k4;
1183 k5 = part4 - part2;
1184
1185 DATA_TYPE out00 = k2 - k0;
1186 DATA_TYPE out01 = k2 + k1;
1187 DATA_TYPE out02 = k2 - k1;
1188 DATA_TYPE out03 = k5 + k4;
1189 DATA_TYPE out04 = k5 - k4;
1190 DATA_TYPE out05 = k3 - k1;
1191
Giorgio Arena149fdf32018-07-04 17:03:33 +01001192#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1193
1194 // Compute destination address
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001195#if defined(NUM_TILES_Y)
1196 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001197#else // defined(NUM_TILES_Y)
Giorgio Arena2d1a8352020-10-26 15:04:08 +00001198 __global DATA_TYPE *dst_addr = (__global DATA_TYPE *)(dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001199#endif // defined(NUM_TILES_Y)
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001200
1201 uint dst_plane_stride = dst_stride_z / sizeof(DATA_TYPE);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001202
Giorgio Arena049989a2021-03-22 17:02:26 +00001203 *((__global DATA_TYPE *)dst_addr) = out00;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001204 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001205 *((__global DATA_TYPE *)dst_addr) = out01;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001206 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001207 *((__global DATA_TYPE *)dst_addr) = out02;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001208 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001209 *((__global DATA_TYPE *)dst_addr) = out03;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001210 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001211 *((__global DATA_TYPE *)dst_addr) = out04;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001212 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001213 *((__global DATA_TYPE *)dst_addr) = out05;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001214 dst_addr += dst_plane_stride;
1215
Giorgio Arena149fdf32018-07-04 17:03:33 +01001216#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001217 *((__global DATA_TYPE *)dst_addr) = out10;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001218 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001219 *((__global DATA_TYPE *)dst_addr) = out11;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001220 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001221 *((__global DATA_TYPE *)dst_addr) = out12;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001222 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001223 *((__global DATA_TYPE *)dst_addr) = out13;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001224 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001225 *((__global DATA_TYPE *)dst_addr) = out14;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001226 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001227 *((__global DATA_TYPE *)dst_addr) = out15;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001228 dst_addr += dst_plane_stride;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001229
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001230 *((__global DATA_TYPE *)dst_addr) = out20;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001231 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001232 *((__global DATA_TYPE *)dst_addr) = out21;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001233 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001234 *((__global DATA_TYPE *)dst_addr) = out22;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001235 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001236 *((__global DATA_TYPE *)dst_addr) = out23;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001237 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001238 *((__global DATA_TYPE *)dst_addr) = out24;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001239 dst_addr += dst_plane_stride;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001240 *((__global DATA_TYPE *)dst_addr) = out25;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001241 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001242
1243 *((__global DATA_TYPE *)dst_addr) = out30;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001244 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001245 *((__global DATA_TYPE *)dst_addr) = out31;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001246 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001247 *((__global DATA_TYPE *)dst_addr) = out32;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001248 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001249 *((__global DATA_TYPE *)dst_addr) = out33;
1250 dst_addr += dst_plane_stride;
1251 *((__global DATA_TYPE *)dst_addr) = out34;
1252 dst_addr += dst_plane_stride;
1253 *((__global DATA_TYPE *)dst_addr) = out35;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001254 dst_addr += dst_plane_stride;
1255
Giorgio Arena049989a2021-03-22 17:02:26 +00001256 *((__global DATA_TYPE *)dst_addr) = out40;
1257 dst_addr += dst_plane_stride;
1258 *((__global DATA_TYPE *)dst_addr) = out41;
1259 dst_addr += dst_plane_stride;
1260 *((__global DATA_TYPE *)dst_addr) = out42;
1261 dst_addr += dst_plane_stride;
1262 *((__global DATA_TYPE *)dst_addr) = out43;
1263 dst_addr += dst_plane_stride;
1264 *((__global DATA_TYPE *)dst_addr) = out44;
1265 dst_addr += dst_plane_stride;
1266 *((__global DATA_TYPE *)dst_addr) = out45;
1267 dst_addr += dst_plane_stride;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001268
Giorgio Arena049989a2021-03-22 17:02:26 +00001269 *((__global DATA_TYPE *)dst_addr) = out50;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001270 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001271 *((__global DATA_TYPE *)dst_addr) = out51;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001272 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001273 *((__global DATA_TYPE *)dst_addr) = out52;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001274 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001275 *((__global DATA_TYPE *)dst_addr) = out53;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001276 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001277 *((__global DATA_TYPE *)dst_addr) = out54;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001278 dst_addr += dst_plane_stride;
Giorgio Arena049989a2021-03-22 17:02:26 +00001279 *((__global DATA_TYPE *)dst_addr) = out55;
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001280 dst_addr += dst_plane_stride;
Giorgio Arena149fdf32018-07-04 17:03:33 +01001281#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001282}
1283
/** This OpenCL kernel computes the input transform when the kernel size is 5x5/5x1 or 1x5 and the output tile is 4x4/4x1 or 1x4 when the data layout is NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 5x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x5, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001315__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
Giorgio Arena149fdf32018-07-04 17:03:33 +01001316 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001317 TENSOR3D_DECLARATION(dst),
1318 uint src_stride_w,
1319 uint dst_stride_w)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001320{
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001321 const int x = get_global_id(0);
1322 const int y = get_global_id(1);
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001323#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001324 const int z = get_global_id(2) % NUM_TILES_Y;
1325 const int b = get_global_id(2) / NUM_TILES_Y;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001326#else // defined(NUM_TILES_Y)
Giorgio Arena2d1a8352020-10-26 15:04:08 +00001327 const int z = get_global_id(2);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001328#endif // defined(NUM_TILES_Y)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001329
1330 // Compute input address
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001331#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001332 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001333#else // defined(NUM_TILES_Y)
Giorgio Arena2d1a8352020-10-26 15:04:08 +00001334 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001335#endif // defined(NUM_TILES_Y)
1336
1337 // Origin coordinates for the width (y) and height (z) in the input tensor
1338 int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
1339 int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
1340
1341 // Coordinates to use to avoid out-of-bound reads
1342 int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
1343 int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
1344
1345 // Boundary conditions
1346 int8 y_cond0 = y_coord_valid0 == y_coord0;
1347 int8 z_cond0 = z_coord_valid0 == z_coord0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001348
1349#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001350 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001351 VEC_DATA_TYPE(DATA_TYPE, 8)
1352 in_row0;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001353 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1354 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1355 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1356 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1357 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1358 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1359 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1360 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1361
1362 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001363
1364 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001365 VEC_DATA_TYPE(DATA_TYPE, 8)
1366 comm_fact0 = 0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001367
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001368 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001369 out0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001370
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001371 OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001372
1373#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001374
1375 // Load the input tile
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001376 VEC_DATA_TYPE(DATA_TYPE, 8)
1377 in_row0;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001378 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1379 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1380 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1381 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1382 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1383 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1384 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1385 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1386
1387 FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001388
1389 // Calculate common factors for intermediate tensor
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001390 VEC_DATA_TYPE(DATA_TYPE, 8)
1391 comm_fact0 = 0.0f;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001392
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001393 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001394 out0 = in_row0;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001395
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001396 OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001397#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001398 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001399 out0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, out7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001400
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001401 // Row0
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001402 out0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1403 out0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1404 out0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1405 out0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1406 out0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1407 out0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1408 out0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1409 out0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001410
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001411 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out0.s, y_cond, z_cond0.s0);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001412
1413 // Row1
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001414 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1415 in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1416 in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1417 in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1418 in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1419 in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1420 in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1421 in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001422
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001423 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001424
1425 // Row2
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001426 in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1427 in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1428 in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1429 in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1430 in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1431 in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1432 in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1433 in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001434
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001435 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001436
1437 // Row3
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001438 in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1439 in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1440 in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1441 in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1442 in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1443 in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1444 in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1445 in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001446
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001447 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001448
1449 // Row4
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001450 in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1451 in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1452 in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1453 in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1454 in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1455 in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1456 in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1457 in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001458
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001459 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001460
1461 // Row5
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001462 in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1463 in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1464 in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1465 in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1466 in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1467 in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1468 in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1469 in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001470
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001471 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001472
1473 // Row6
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001474 in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1475 in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1476 in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1477 in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1478 in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1479 in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1480 in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1481 in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001482
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001483 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001484
1485 // Row7
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001486 out7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1487 out7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1488 out7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1489 out7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1490 out7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1491 out7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1492 out7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1493 out7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001494
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001495 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, out7.s, y_cond, z_cond0.s7);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001496
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001497 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001498 out1, out2, out3, out4, out5, out6;
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001499 VEC_DATA_TYPE(DATA_TYPE, 8)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001500 comm_fact0, comm_fact1;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001501
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001502 BT_MULTIPLY_4x4_5x5(out, in_row, comm_fact0, comm_fact1, DATA_TYPE);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001503
1504 // Calculate output rows (reuse comm_fact0 vector)
Aleksandr Nikolaev2ca5b082021-03-18 14:03:48 +00001505 OUTPUT_ROW_4x4_5x5(out0, comm_fact0);
1506 OUTPUT_ROW_4x4_5x5(out1, comm_fact0);
1507 OUTPUT_ROW_4x4_5x5(out2, comm_fact0);
1508 OUTPUT_ROW_4x4_5x5(out3, comm_fact0);
1509 OUTPUT_ROW_4x4_5x5(out4, comm_fact0);
1510 OUTPUT_ROW_4x4_5x5(out5, comm_fact0);
1511 OUTPUT_ROW_4x4_5x5(out6, comm_fact0);
1512 OUTPUT_ROW_4x4_5x5(out7, comm_fact0);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001513#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1514
1515 // Store values across the channels
1516#if defined(NUM_TILES_Y)
1517 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
1518#else /* NUM_TILES_Y */
1519 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1520#endif /* NUM_TILES_Y */
1521
1522 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1523 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1524 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1525 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1526 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1527 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1528 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1529 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
1530
1531#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1532 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1533 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1534 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1535 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1536 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1537 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1538 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1539 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1540 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1541 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1542 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1543 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1544 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1545 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1546 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1547 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1548 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1549 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1550 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1551 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1552 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1553 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1554 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1555 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1556 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1557 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1558 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1559 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1560 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1561 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1562 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1563 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1564 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1565 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1566 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1567 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1568 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1569 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1570 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1571 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1572 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1573 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1574 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1575 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1576 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1577 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1578 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1579 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1580 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1581 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1582 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1583 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1584 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1585 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1586 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1587 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
1588#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1589}
1590
1591/** This OpenCL kernel computes the input transform when the kernel size is 7x7/7x1/1x7 and the output tile is 2x2/2x1/1x2 when the data layout is NHWC
1592 *
1593 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=7).
1594 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
1595 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
1596 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
1597 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1598 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1599 * @note If this kernel is used to perform Winograd input transform 7x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1600 * @note If this kernel is used to perform Winograd input transform 1x7, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
1601 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
1602 *
1603 * @param[in] src_ptr Pointer to the source image. Supported data types: F16/F32
1604 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1605 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1606 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1607 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1608 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1609 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1610 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1611 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1612 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1613 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1614 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1615 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1616 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1617 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1618 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1619 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
1620 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
1621 */
1622__kernel void winograd_input_transform_2x2_7x7_stepz1_nhwc(
1623 TENSOR3D_DECLARATION(src),
1624 TENSOR3D_DECLARATION(dst),
1625 uint src_stride_w,
1626 uint dst_stride_w)
1627{
1628 const int x = get_global_id(0);
1629 const int y = get_global_id(1);
1630#if defined(NUM_TILES_Y)
1631 const int z = get_global_id(2) % NUM_TILES_Y;
1632 const int b = get_global_id(2) / NUM_TILES_Y;
1633#else /* defined(NUM_TILES_Y) */
1634 const int z = get_global_id(2);
1635#endif /* defined(NUM_TILES_Y) */
1636
1637 // Compute input address
1638#if defined(NUM_TILES_Y)
1639 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + b * src_stride_w;
1640#else /* defined(NUM_TILES_Y) */
1641 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(DATA_TYPE);
1642#endif /* defined(NUM_TILES_Y) */
1643
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001644 // Origin coordinates for the width (y) and height (z) in the input tensor
1645 int8 y_coord0 = (int8)(y * OUTPUT_TILE_W) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
1646 int8 z_coord0 = (int8)(z * OUTPUT_TILE_H) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_TOP;
1647
1648 // Coordinates to use to avoid out-of-bound reads
1649 int8 y_coord_valid0 = clamp(y_coord0, (int8)0, (int8)((int)SRC_DIM_1 - 1));
1650 int8 z_coord_valid0 = clamp(z_coord0, (int8)0, (int8)((int)SRC_DIM_2 - 1));
1651
1652 // Boundary conditions
1653 int8 y_cond0 = y_coord_valid0 == y_coord0;
1654 int8 z_cond0 = z_coord_valid0 == z_coord0;
1655
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001656#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1657
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001658 // Load the input tile
1659 VEC_DATA_TYPE(DATA_TYPE, 8)
1660 in_row0;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001661 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1662 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1663 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1664 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1665 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1666 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1667 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1668 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1669
1670 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001671
1672 VEC_DATA_TYPE(DATA_TYPE, 8)
1673 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1674
1675 VEC_DATA_TYPE(DATA_TYPE, 8)
1676 tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
1677
1678 VEC_DATA_TYPE(DATA_TYPE, 8)
1679 comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1680
1681 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1682
1683#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001684 // Load the input tile
1685 VEC_DATA_TYPE(DATA_TYPE, 8)
1686 in_row0;
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001687 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1688 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1689 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1690 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1691 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1692 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1693 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1694 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1695
1696 FILL_ZERO_OUT_OF_BOUND_8_NHWC_V(DATA_TYPE, in_row0.s, y_cond0.s0, z_cond);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001697
1698 // Calculate common factors for intermediate tensor
1699 VEC_DATA_TYPE(DATA_TYPE, 8)
1700 tmp0 = ((VEC_DATA_TYPE(DATA_TYPE, 8)) - 36.0f) * in_row0;
1701
1702 VEC_DATA_TYPE(DATA_TYPE, 8)
1703 out0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1704
1705 VEC_DATA_TYPE(DATA_TYPE, 8)
1706 comm_fact0 = (VEC_DATA_TYPE(DATA_TYPE, 8))0.0f;
1707
1708 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1709#else // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1710 VEC_DATA_TYPE(DATA_TYPE, 8)
1711 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7;
1712
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001713 // Row0
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001714 in_row0.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1715 in_row0.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1716 in_row0.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1717 in_row0.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1718 in_row0.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1719 in_row0.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1720 in_row0.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
1721 in_row0.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s0 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001722
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001723 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row0.s, y_cond, z_cond0.s0);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001724
1725 // Row1
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001726 in_row1.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1727 in_row1.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1728 in_row1.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1729 in_row1.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1730 in_row1.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1731 in_row1.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1732 in_row1.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
1733 in_row1.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s1 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001734
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001735 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row1.s, y_cond, z_cond0.s1);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001736
1737 // Row2
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001738 in_row2.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1739 in_row2.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1740 in_row2.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1741 in_row2.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1742 in_row2.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1743 in_row2.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1744 in_row2.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
1745 in_row2.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s2 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001746
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001747 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row2.s, y_cond, z_cond0.s2);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001748
1749 // Row3
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001750 in_row3.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1751 in_row3.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1752 in_row3.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1753 in_row3.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1754 in_row3.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1755 in_row3.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1756 in_row3.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
1757 in_row3.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s3 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001758
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001759 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row3.s, y_cond, z_cond0.s3);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001760
1761 // Row4
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001762 in_row4.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1763 in_row4.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1764 in_row4.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1765 in_row4.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1766 in_row4.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1767 in_row4.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1768 in_row4.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
1769 in_row4.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s4 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001770
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001771 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row4.s, y_cond, z_cond0.s4);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001772
1773 // Row5
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001774 in_row5.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1775 in_row5.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1776 in_row5.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1777 in_row5.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1778 in_row5.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1779 in_row5.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1780 in_row5.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
1781 in_row5.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s5 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001782
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001783 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row5.s, y_cond, z_cond0.s5);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001784
1785 // Row6
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001786 in_row6.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1787 in_row6.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1788 in_row6.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1789 in_row6.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1790 in_row6.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1791 in_row6.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1792 in_row6.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
1793 in_row6.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s6 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001794
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001795 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row6.s, y_cond, z_cond0.s6);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001796
1797 // Row7
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001798 in_row7.s0 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s0 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1799 in_row7.s1 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s1 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1800 in_row7.s2 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s2 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1801 in_row7.s3 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s3 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1802 in_row7.s4 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s4 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1803 in_row7.s5 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s5 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1804 in_row7.s6 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s6 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
1805 in_row7.s7 = *(__global DATA_TYPE *)(src_addr + y_coord_valid0.s7 * (int)src_stride_y + z_coord_valid0.s7 * src_stride_z);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001806
Gian Marco Iodicebc6c3742020-10-19 12:49:44 +01001807 FILL_ZERO_OUT_OF_BOUND_8_NHWC_H(DATA_TYPE, in_row7.s, y_cond, z_cond0.s7);
Michele Di Giorgiof955d512019-02-27 14:26:51 +00001808
1809 VEC_DATA_TYPE(DATA_TYPE, 8)
1810 comm_fact0 = (DATA_TYPE)36.0f * in_row2 - (DATA_TYPE)13.0f * in_row4 + in_row6;
1811 VEC_DATA_TYPE(DATA_TYPE, 8)
1812 comm_fact1 = (DATA_TYPE)36.0f * in_row1 - (DATA_TYPE)13.0f * in_row3 + in_row5;
1813 VEC_DATA_TYPE(DATA_TYPE, 8)
1814 comm_fact2 = (DATA_TYPE)9.0f * in_row2 - (DATA_TYPE)10.0f * in_row4 + in_row6;
1815 VEC_DATA_TYPE(DATA_TYPE, 8)
1816 comm_fact3 = (DATA_TYPE)18.0f * in_row1 - (DATA_TYPE)20.0f * in_row3 + (DATA_TYPE)2.0f * in_row5;
1817 VEC_DATA_TYPE(DATA_TYPE, 8)
1818 comm_fact4 = (DATA_TYPE)4.0f * in_row2 - (DATA_TYPE)5.0f * in_row4 + in_row6;
1819 VEC_DATA_TYPE(DATA_TYPE, 8)
1820 comm_fact5 = (DATA_TYPE)12.0f * in_row1 - (DATA_TYPE)15.0f * in_row3 + (DATA_TYPE)3.0f * in_row5;
1821
1822 // Calculate intermediate tensors
1823 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp0 = -(DATA_TYPE)36.0f * in_row0 + (DATA_TYPE)49.0f * in_row2 - (DATA_TYPE)14.0f * in_row4 + in_row6;
1824 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp1 = comm_fact0 - comm_fact1;
1825 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp2 = comm_fact0 + comm_fact1;
1826 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp3 = comm_fact2 - comm_fact3;
1827 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp4 = comm_fact2 + comm_fact3;
1828 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp5 = comm_fact4 - comm_fact5;
1829 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp6 = comm_fact4 + comm_fact5;
1830 const VEC_DATA_TYPE(DATA_TYPE, 8) tmp7 = -(DATA_TYPE)36.0f * in_row1 + (DATA_TYPE)49.0f * in_row3 - (DATA_TYPE)14.0f * in_row5 + in_row7;
1831
1832 VEC_DATA_TYPE(DATA_TYPE, 8)
1833 out0, out1, out2, out3, out4, out5, out6, out7;
1834
1835 OUTPUT_ROW_2x2_7x7(out0, tmp0, comm_fact0);
1836 OUTPUT_ROW_2x2_7x7(out1, tmp1, comm_fact0);
1837 OUTPUT_ROW_2x2_7x7(out2, tmp2, comm_fact0);
1838 OUTPUT_ROW_2x2_7x7(out3, tmp3, comm_fact0);
1839 OUTPUT_ROW_2x2_7x7(out4, tmp4, comm_fact0);
1840 OUTPUT_ROW_2x2_7x7(out5, tmp5, comm_fact0);
1841 OUTPUT_ROW_2x2_7x7(out6, tmp6, comm_fact0);
1842 OUTPUT_ROW_2x2_7x7(out7, tmp7, comm_fact0);
1843
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001844#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001845
1846 // Store values across the channels
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001847#if defined(NUM_TILES_Y)
Georgios Pinitasc55beee2018-10-23 15:23:23 +01001848 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y + b * dst_stride_w;
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001849#else /* NUM_TILES_Y */
1850 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(DATA_TYPE) + (y + z * (int)NUM_TILES_X) * dst_stride_y;
1851#endif /* NUM_TILES_Y */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001852
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001853 *((__global DATA_TYPE *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
1854 *((__global DATA_TYPE *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
1855 *((__global DATA_TYPE *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
1856 *((__global DATA_TYPE *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
1857 *((__global DATA_TYPE *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
1858 *((__global DATA_TYPE *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
1859 *((__global DATA_TYPE *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
1860 *((__global DATA_TYPE *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001861
1862#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01001863 *((__global DATA_TYPE *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
1864 *((__global DATA_TYPE *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
1865 *((__global DATA_TYPE *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
1866 *((__global DATA_TYPE *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
1867 *((__global DATA_TYPE *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
1868 *((__global DATA_TYPE *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
1869 *((__global DATA_TYPE *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
1870 *((__global DATA_TYPE *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
1871 *((__global DATA_TYPE *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
1872 *((__global DATA_TYPE *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
1873 *((__global DATA_TYPE *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
1874 *((__global DATA_TYPE *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
1875 *((__global DATA_TYPE *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
1876 *((__global DATA_TYPE *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
1877 *((__global DATA_TYPE *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
1878 *((__global DATA_TYPE *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
1879 *((__global DATA_TYPE *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
1880 *((__global DATA_TYPE *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
1881 *((__global DATA_TYPE *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
1882 *((__global DATA_TYPE *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
1883 *((__global DATA_TYPE *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
1884 *((__global DATA_TYPE *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
1885 *((__global DATA_TYPE *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
1886 *((__global DATA_TYPE *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
1887 *((__global DATA_TYPE *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
1888 *((__global DATA_TYPE *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
1889 *((__global DATA_TYPE *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
1890 *((__global DATA_TYPE *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
1891 *((__global DATA_TYPE *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
1892 *((__global DATA_TYPE *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
1893 *((__global DATA_TYPE *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
1894 *((__global DATA_TYPE *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
1895 *((__global DATA_TYPE *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
1896 *((__global DATA_TYPE *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
1897 *((__global DATA_TYPE *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
1898 *((__global DATA_TYPE *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
1899 *((__global DATA_TYPE *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
1900 *((__global DATA_TYPE *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
1901 *((__global DATA_TYPE *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
1902 *((__global DATA_TYPE *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
1903 *((__global DATA_TYPE *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
1904 *((__global DATA_TYPE *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
1905 *((__global DATA_TYPE *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
1906 *((__global DATA_TYPE *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
1907 *((__global DATA_TYPE *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
1908 *((__global DATA_TYPE *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
1909 *((__global DATA_TYPE *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
1910 *((__global DATA_TYPE *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
1911 *((__global DATA_TYPE *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
1912 *((__global DATA_TYPE *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
1913 *((__global DATA_TYPE *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
1914 *((__global DATA_TYPE *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
1915 *((__global DATA_TYPE *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
1916 *((__global DATA_TYPE *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
1917 *((__global DATA_TYPE *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
1918 *((__global DATA_TYPE *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01001919#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena149fdf32018-07-04 17:03:33 +01001920}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00001921#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01001922
1923#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 2D (2x2 / 3x3) kernel with the arguments in declaration order.
    // NOTE(review): the 1D horizontal specialisation is presumably selected inside the callee
    // through WINOGRAD_INPUT_TRANSFORM_HORIZONTAL - confirm against its definition.
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
1977
/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is a multiple of 2
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_input_transform_2x2_3x3_stepz2_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 2D (2x2 / 3x3, step-z 2) kernel with the arguments in declaration order.
    // NOTE(review): the 1D horizontal specialisation is presumably selected inside the callee
    // through WINOGRAD_INPUT_TRANSFORM_HORIZONTAL - confirm against its definition.
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
2031
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_input_transform_4x4_3x3_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 2D (4x4 / 3x3) kernel with the arguments in declaration order.
    // NOTE(review): the 1D horizontal specialisation is presumably selected inside the callee
    // through WINOGRAD_INPUT_TRANSFORM_HORIZONTAL - confirm against its definition.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002085
/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 when the data layout is NCHW
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 * @note This kernel is a thin wrapper: it forwards all of its arguments, unchanged, to winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in] src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in] dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_4x1_5x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Delegate to the 2D (4x4 / 5x5) kernel with the arguments in declaration order.
    // NOTE(review): the 1D horizontal specialisation is presumably selected inside the callee
    // through WINOGRAD_INPUT_TRANSFORM_HORIZONTAL - confirm against its definition.
    winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes,
                                                 src_stride_w,
                                                 dst_stride_w);
}
2139
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002140#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002141/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 for data layout NHWC
2142 *
2143 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2144 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2145 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
2146 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2147 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
2148 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2149 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002150 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002151 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002152 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002153 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2154 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2155 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2156 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2157 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2158 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2159 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2160 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2161 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2162 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2163 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2164 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2165 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2166 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2167 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002168 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2169 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002170 */
2171__kernel void winograd_input_transform_4x1_3x1_stepz1_nhwc(
2172 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002173 TENSOR3D_DECLARATION(dst),
2174 uint src_stride_w,
2175 uint dst_stride_w)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002176{
2177 winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
2178 src_stride_x,
2179 src_step_x,
2180 src_stride_y,
2181 src_step_y,
2182 src_stride_z,
2183 src_step_z,
2184 src_offset_first_element_in_bytes,
2185 dst_ptr,
2186 dst_stride_x,
2187 dst_step_x,
2188 dst_stride_y,
2189 dst_step_y,
2190 dst_stride_z,
2191 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002192 dst_offset_first_element_in_bytes,
2193 src_stride_w,
2194 dst_stride_w);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002195}
2196
2197/** This OpenCL kernel computes the input transform when the kernel size is 5x1 and the output tile is 4x1 for data layout NHWC
2198 *
2199 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2200 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2201 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
2202 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2203 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
2204 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2205 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002206 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002207 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002208 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002209 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2210 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2211 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2212 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2213 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2214 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2215 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2216 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2217 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2218 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2219 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2220 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2221 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2222 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2223 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002224 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2225 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002226 */
2227__kernel void winograd_input_transform_4x1_5x1_stepz1_nhwc(
2228 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002229 TENSOR3D_DECLARATION(dst),
2230 uint src_stride_w,
2231 uint dst_stride_w)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002232{
2233 winograd_input_transform_4x4_5x5_stepz1_nhwc(src_ptr,
2234 src_stride_x,
2235 src_step_x,
2236 src_stride_y,
2237 src_step_y,
2238 src_stride_z,
2239 src_step_z,
2240 src_offset_first_element_in_bytes,
2241 dst_ptr,
2242 dst_stride_x,
2243 dst_step_x,
2244 dst_stride_y,
2245 dst_step_y,
2246 dst_stride_z,
2247 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002248 dst_offset_first_element_in_bytes,
2249 src_stride_w,
2250 dst_stride_w);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002251}
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002252
2253/** This OpenCL kernel computes the input transform when the kernel size is 7x1 and the output tile is 2x1 for data layout NHWC
2254 *
2255 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=7).
2256 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2257 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
2258 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2259 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=7
2260 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2261 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
2262 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
2263 *
2264 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
2265 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2266 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2267 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2268 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2269 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2270 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2271 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2272 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2273 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2274 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2275 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2276 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2277 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2278 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2280 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2281 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
2282 */
2283__kernel void winograd_input_transform_2x1_7x1_stepz1_nhwc(
2284 TENSOR3D_DECLARATION(src),
2285 TENSOR3D_DECLARATION(dst),
2286 uint src_stride_w,
2287 uint dst_stride_w)
2288{
2289 winograd_input_transform_2x2_7x7_stepz1_nhwc(src_ptr,
2290 src_stride_x,
2291 src_step_x,
2292 src_stride_y,
2293 src_step_y,
2294 src_stride_z,
2295 src_step_z,
2296 src_offset_first_element_in_bytes,
2297 dst_ptr,
2298 dst_stride_x,
2299 dst_step_x,
2300 dst_stride_y,
2301 dst_step_y,
2302 dst_stride_z,
2303 dst_step_z,
2304 dst_offset_first_element_in_bytes,
2305 src_stride_w,
2306 dst_stride_w);
2307}
#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002309#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
2310
2311#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
2312/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
2313 *
2314 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2315 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2316 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2317 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
2318 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002319 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002320 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002321 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002322 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2323 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2324 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2325 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2326 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2327 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2328 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2329 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2330 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2331 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2332 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2333 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2334 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2335 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2336 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002337 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2338 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002339 */
2340__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
2341 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002342 TENSOR3D_DECLARATION(dst),
2343 uint src_stride_w,
2344 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002345{
2346 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
2347 src_stride_x,
2348 src_step_x,
2349 src_stride_y,
2350 src_step_y,
2351 src_stride_z,
2352 src_step_z,
2353 src_offset_first_element_in_bytes,
2354 dst_ptr,
2355 dst_stride_x,
2356 dst_step_x,
2357 dst_stride_y,
2358 dst_step_y,
2359 dst_stride_z,
2360 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002361 dst_offset_first_element_in_bytes,
2362 src_stride_w,
2363 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002364}
2365
2366/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is multiple of 2
2367 *
2368 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2369 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2370 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2371 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
2372 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002373 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002374 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002375 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002376 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2377 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2378 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2379 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2380 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2381 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2382 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2383 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2384 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2385 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2386 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2387 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2388 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2389 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2390 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002391 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2392 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002393 */
2394__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
2395 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002396 TENSOR3D_DECLARATION(dst),
2397 uint src_stride_w,
2398 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002399{
2400 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
2401 src_stride_x,
2402 src_step_x,
2403 src_stride_y,
2404 src_step_y,
2405 src_stride_z,
2406 src_step_z,
2407 src_offset_first_element_in_bytes,
2408 dst_ptr,
2409 dst_stride_x,
2410 dst_step_x,
2411 dst_stride_y,
2412 dst_step_y,
2413 dst_stride_z,
2414 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002415 dst_offset_first_element_in_bytes,
2416 src_stride_w,
2417 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002418}
2419
2420/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
2421 *
2422 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2423 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2424 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2425 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
2426 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002427 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002428 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002429 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002430 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2431 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2432 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2433 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2434 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2435 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2436 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2437 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2438 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2439 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2440 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2441 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2442 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2443 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2444 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002445 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2446 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002447 */
2448__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
2449 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002450 TENSOR3D_DECLARATION(dst),
2451 uint src_stride_w,
2452 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002453{
2454 winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
2455 src_stride_x,
2456 src_step_x,
2457 src_stride_y,
2458 src_step_y,
2459 src_stride_z,
2460 src_step_z,
2461 src_offset_first_element_in_bytes,
2462 dst_ptr,
2463 dst_stride_x,
2464 dst_step_x,
2465 dst_stride_y,
2466 dst_step_y,
2467 dst_stride_z,
2468 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002469 dst_offset_first_element_in_bytes,
2470 src_stride_w,
2471 dst_stride_w);
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002472}
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002473
2474/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4
2475 *
2476 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2477 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2478 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2479 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
2480 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002481 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002482 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002483 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002484 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2485 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2486 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2487 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2488 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2489 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2490 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2491 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2492 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2493 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2494 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2495 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2496 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2497 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2498 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002499 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2500 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002501 */
2502__kernel void winograd_input_transform_1x4_1x5_stepz1_nchw(
2503 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002504 TENSOR3D_DECLARATION(dst),
2505 uint src_stride_w,
2506 uint dst_stride_w)
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002507{
2508 winograd_input_transform_4x4_5x5_stepz1_nchw(src_ptr,
2509 src_stride_x,
2510 src_step_x,
2511 src_stride_y,
2512 src_step_y,
2513 src_stride_z,
2514 src_step_z,
2515 src_offset_first_element_in_bytes,
2516 dst_ptr,
2517 dst_stride_x,
2518 dst_step_x,
2519 dst_stride_y,
2520 dst_step_y,
2521 dst_stride_z,
2522 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002523 dst_offset_first_element_in_bytes,
2524 src_stride_w,
2525 dst_stride_w);
Gian Marco Iodice876be2a2018-07-03 12:22:09 +01002526}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002527
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002528#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002529/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 for data layout NHWC
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002530 *
2531 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002532 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM1 (e.g. -DSRC_DIM_1=112)
2533 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM2 (e.g. -DSRC_DIM_2=112)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002534 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002535 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002536 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002537 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002538 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002539 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002540 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002541 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2542 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2543 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2544 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2545 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2546 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2547 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2548 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2549 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2550 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2551 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2552 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2553 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2554 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2555 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002556 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2557 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002558 */
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002559__kernel void winograd_input_transform_1x4_1x3_stepz1_nhwc(
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002560 TENSOR3D_DECLARATION(src),
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002561 TENSOR3D_DECLARATION(dst),
2562 uint src_stride_w,
2563 uint dst_stride_w)
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002564{
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002565 winograd_input_transform_4x4_3x3_stepz1_nhwc(src_ptr,
2566 src_stride_x,
2567 src_step_x,
2568 src_stride_y,
2569 src_step_y,
2570 src_stride_z,
2571 src_step_z,
2572 src_offset_first_element_in_bytes,
2573 dst_ptr,
2574 dst_stride_x,
2575 dst_step_x,
2576 dst_stride_y,
2577 dst_step_y,
2578 dst_stride_z,
2579 dst_step_z,
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002580 dst_offset_first_element_in_bytes,
2581 src_stride_w,
2582 dst_stride_w);
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002583}
Giorgio Arenaa50e5e02018-07-02 13:42:23 +01002584
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002585/** This OpenCL kernel computes the input transform when the kernel size is 1x5 and the output tile is 1x4 for data layout NHWC
2586 *
2587 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
2590 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2591 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2592 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
2593 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002594 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002595 *
Vidhya Sudhan Loganathan71ecf392018-08-31 16:10:16 +01002596 * @param[in] src_ptr Pointer to the source image. Supported data types: F32/F16
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002597 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2598 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2599 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2600 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2601 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2602 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2603 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2604 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2605 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2606 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2607 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2608 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2609 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2610 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2611 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc55beee2018-10-23 15:23:23 +01002612 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
2613 * @param[in] dst_stride_w Stride of the destination tensor in W dimension (in bytes)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002614 */
__kernel void winograd_input_transform_1x4_1x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Thin entry point: every argument is handed over unchanged to the
    // 4x4/5x5 NHWC routine, which performs the actual work.
    winograd_input_transform_4x4_5x5_stepz1_nhwc(
        // Source tensor
        src_ptr,
        src_stride_x, src_step_x,
        src_stride_y, src_step_y,
        src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        // Destination tensor
        dst_ptr,
        dst_stride_x, dst_step_x,
        dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        // Strides along the fourth (W) dimension
        src_stride_w,
        dst_stride_w);
}
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002640
/** This OpenCL kernel computes the input transform when the kernel size is 1x7 and the output tile is 1x2 for data layout NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=7).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The data type must be passed at compile time using -DDATA_TYPE e.g. -DDATA_TYPE=float. Supported data types: float/half.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32/F16
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  dst_stride_w                      Stride of the destination tensor in W dimension (in bytes)
 */
__kernel void winograd_input_transform_1x2_1x7_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
    uint src_stride_w,
    uint dst_stride_w)
{
    // Thin entry point: every argument is handed over unchanged to the
    // 2x2/7x7 NHWC routine, which performs the actual work.
    winograd_input_transform_2x2_7x7_stepz1_nhwc(
        // Source tensor
        src_ptr,
        src_stride_x, src_step_x,
        src_stride_y, src_step_y,
        src_stride_z, src_step_z,
        src_offset_first_element_in_bytes,
        // Destination tensor
        dst_ptr,
        dst_stride_x, dst_step_x,
        dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes,
        // Strides along the fourth (W) dimension
        src_stride_w,
        dst_stride_w);
}
Georgios Pinitasffb57a02018-10-29 18:01:52 +00002696#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodiced28b7512018-07-06 12:59:28 +01002697#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Michele Di Giorgiof955d512019-02-27 14:26:51 +00002698#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)