blob: f40a969ea020f860edaf99e4c4b3c359e74116fd [file] [log] [blame]
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodiced2fab732018-03-02 11:18:12 +000026#if defined(NUM_CHANNELS)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +000027
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 2x2
 *
 * Each work-item reads one 3x3 filter plane w (three rows of three floats, rows separated by src_stride_y)
 * and computes the 4x4 tile G * w * G^T, where
 *
 *         |  1     0     0  |
 *     G = | 0.5   0.5   0.5 |
 *         | 0.5  -0.5   0.5 |
 *         |  0     0     1  |
 *
 * The 16 results are written in row-major order to 16 consecutive Z planes of the destination tensor, at
 * X = get_global_id(2) / NUM_CHANNELS (filter index) and Y = get_global_id(2) % NUM_CHANNELS (channel index).
 *
 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // NOTE(review): the helper presumably positions the struct on this work-item's 3x3 plane,
    // using NUM_CHANNELS to split the Z index into (channel, filter) - confirm against helpers.h
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the three rows of the 3x3 filter plane (one float3 per row)
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));

    // Transform the 3x3 tile in a 4x4 tile
    float4 out0 = 0.0f;
    float4 out1 = 0.0f;
    float4 out2 = 0.0f;
    float4 out3 = 0.0f;

    // Row 0: first filter row combined along X
    out0.s0 = (w0.s0);
    out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3 = (w0.s2);

    // Row 1: 0.5 * (w0 + w1 + w2) combined along X
    out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2: 0.5 * (w0 - w1 + w2) combined along X
    out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3: last filter row combined along X
    out3.s0 = (w2.s0);
    out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3 = (w2.s2);

    // Split the flat Z work-item index into filter and channel indices
    int z  = get_global_id(2);
    int x0 = z / NUM_CHANNELS; // idx filter
    int y0 = z % NUM_CHANNELS; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 16 values across the 16 channels (row-major order of the 4x4 tile)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
}
Giorgio Arena2d9de0a2018-03-15 17:58:20 +0000119
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 4x4
 *
 * Each work-item reads one 3x3 filter plane w (three rows of three floats, rows separated by src_stride_y)
 * and computes the 6x6 tile G * w * G^T, where
 *
 *         |  1/4     0      0   |
 *         | -1/6   -1/6   -1/6  |
 *     G = | -1/6    1/6   -1/6  |
 *         |  1/24   1/12   1/6  |
 *         |  1/24  -1/12   1/6  |
 *         |   0      0      1   |
 *
 * The 36 results are written in row-major order to 36 consecutive Z planes of the destination tensor, at
 * X = get_global_id(2) / NUM_CHANNELS (filter index) and Y = get_global_id(2) % NUM_CHANNELS (channel index).
 *
 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // NOTE(review): the helper presumably positions the struct on this work-item's 3x3 plane,
    // using NUM_CHANNELS to split the Z index into (channel, filter) - confirm against helpers.h
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the three rows of the 3x3 filter plane (one float3 per row)
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));

    // Transform the 3x3 tile in a 6x6 tile
    // float8 is used as the row container; only components s0..s5 are written and stored
    float8 out0 = 0.0f;
    float8 out1 = 0.0f;
    float8 out2 = 0.0f;
    float8 out3 = 0.0f;
    float8 out4 = 0.0f;
    float8 out5 = 0.0f;

    // Row 0
    out0.s0 = (w0.s0) / 16.f;
    out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5 = (w0.s2) / 4.f;

    // Row 1
    out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    out5.s0 = (w2.s0) / 4.f;
    out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5 = (w2.s2);

    // Split the flat Z work-item index into filter and channel indices
    int z  = get_global_id(2);
    int x0 = z / NUM_CHANNELS; // idx filter
    int y0 = z % NUM_CHANNELS; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 36 values across the 36 channels (row-major order of the 6x6 tile)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
}
Giorgio Arena9373c8b2018-04-11 19:07:17 +0100257
258/** This OpenCL kernel performs Winograd filter transform 5x5 when the data format is NCHW and the output tile is 4x4
259 *
260 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
261 *
262 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
263 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
264 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
265 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
266 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
267 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
268 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
269 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
270 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
271 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
272 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
273 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
274 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
275 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
276 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
277 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
278 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
279 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
280 */
281__kernel void winograd_filter_transform_4x4_5x5_nchw(
282 TENSOR4D_DECLARATION(src),
283 TENSOR3D_DECLARATION(dst))
284{
285 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
286
287 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
288
289 // Load the values from the input tensor
290 const char stride_x = 4 * sizeof(float); // Used for accessing the last value in each row
291 const uchar8 stride_y = (uchar8)(0, 1, 2, 3, 4, 0, 0, 0) * (uchar8)src_stride_y;
292
293 float4 w00 = vload4(0, (__global float *)(src_addr + stride_y.s0));
294 float w01 = *((__global float *)(src_addr + stride_y.s0 + stride_x));
295 float4 w10 = vload4(0, (__global float *)(src_addr + stride_y.s1));
296 float w11 = *((__global float *)(src_addr + stride_y.s1 + stride_x));
297 float4 w20 = vload4(0, (__global float *)(src_addr + stride_y.s2));
298 float w21 = *((__global float *)(src_addr + stride_y.s2 + stride_x));
299 float4 w30 = vload4(0, (__global float *)(src_addr + stride_y.s3));
300 float w31 = *((__global float *)(src_addr + stride_y.s3 + stride_x));
301 float4 w40 = vload4(0, (__global float *)(src_addr + stride_y.s4));
302 float w41 = *((__global float *)(src_addr + stride_y.s4 + stride_x));
303
304 // Transform the 3x3 tile in a 8x8 tile
305 float8 out0 = 0.0f;
306 float8 out1 = 0.0f;
307 float8 out2 = 0.0f;
308 float8 out3 = 0.0f;
309 float8 out4 = 0.0f;
310 float8 out5 = 0.0f;
311 float8 out6 = 0.0f;
312 float8 out7 = 0.0f;
313
314 // Row 0
315 out0.s0 = w00.s0;
316 out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
317 out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
318 out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
319 out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
320 out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
321 out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
322 out0.s7 = w01;
323
324 // Row 1
325 out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
326 out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
327 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
328 out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
329 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
330 out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
331 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
332 out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
333 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
334 out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
335 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
336 out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
337 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
338 out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
339
340 // Row 2
341 out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
342 out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
343 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
344 out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
345 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
346 out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
347 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
348 out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
349 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
350 out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
351 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
352 out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
353 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
354 out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
355
356 // Row 3
357 out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
358 out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
359 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
360 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
361 out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
362 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
363 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
364 out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
365 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
366 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
367 out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
368 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
369 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
370 out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
371 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
372 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
373 out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
374 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
375 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
376 out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
377
378 // Row 4
379 out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
380 out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
381 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
382 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
383 out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
384 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
385 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
386 out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
387 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
388 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
389 out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
390 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
391 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
392 out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
393 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
394 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
395 out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
396 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
397 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
398 out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
399
400 // Row 5
401 out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
402 out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
403 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
404 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
405 out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
406 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
407 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
408 out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
409 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
410 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
411 out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
412 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
413 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
414 out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
415 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
416 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
417 out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
418 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
419 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
420 out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
421
422 // Row 6
423 out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
424 out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
425 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
426 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
427 out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
428 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
429 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
430 out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
431 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
432 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
433 out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
434 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
435 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
436 out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
437 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
438 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
439 out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
440 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
441 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
442 out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
443
444 // Row 7
445 out7.s0 = w40.s0;
446 out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
447 out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
448 out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
449 out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
450 out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
451 out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
452 out7.s7 = w41;
453
454 int z = get_global_id(2);
455 int x0 = z / NUM_CHANNELS; // idx filter
456 int y0 = z % NUM_CHANNELS; // idx channel
457
458 // Get output address
459 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
460
461 // Store the 64 values across the 64 channels
462 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
463 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
464 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
465 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
466 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
467 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
468 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
469 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
470 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
471 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
472 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
473 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
474 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
475 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
476 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
477 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
478 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
479 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
480 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
481 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
482 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
483 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
484 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
485 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
486 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
487 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
488 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
489 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
490 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
491 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
492 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
493 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
494 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
495 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
496 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
497 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
498 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
499 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
500 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
501 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
502 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
503 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
504 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
505 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
506 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
507 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
508 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
509 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
510 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
511 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
512 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
513 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
514 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
515 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
516 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
517 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
518 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
519 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
520 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
521 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
522 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
523 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
524 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
525 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
526}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000527#endif // defined(NUM_CHANNELS)
528
529#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +0000530/** This OpenCL kernel computes the input transform when the kernel size is 3x3 and the output tile is 2x2
531 *
532 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
533 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
534 *
535 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
536 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
537 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
538 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
539 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
540 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
541 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
542 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
543 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
544 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
545 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
546 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
547 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
548 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
549 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
550 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
551 */
552__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
553 TENSOR3D_DECLARATION(src),
554 TENSOR3D_DECLARATION(dst))
555{
556 int x = get_global_id(0);
557 int y = get_global_id(1);
558 int z = get_global_id(2);
559
560 // Compute input address
561 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
562
563 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
564
565 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
566 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
567 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
568 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
569
570 float4 tmp0 = in_row0 - in_row2;
571 float4 tmp1 = in_row1 + in_row2;
572 float4 tmp2 = in_row2 - in_row1;
573 float4 tmp3 = in_row1 - in_row3;
574
575 float out00 = tmp0.s0 - tmp0.s2;
576 float out01 = tmp0.s1 + tmp0.s2;
577 float out02 = tmp0.s2 - tmp0.s1;
578 float out03 = tmp0.s1 - tmp0.s3;
579
580 float out10 = tmp1.s0 - tmp1.s2;
581 float out11 = tmp1.s1 + tmp1.s2;
582 float out12 = tmp1.s2 - tmp1.s1;
583 float out13 = tmp1.s1 - tmp1.s3;
584
585 float out20 = tmp2.s0 - tmp2.s2;
586 float out21 = tmp2.s1 + tmp2.s2;
587 float out22 = tmp2.s2 - tmp2.s1;
588 float out23 = tmp2.s1 - tmp2.s3;
589
590 float out30 = tmp3.s0 - tmp3.s2;
591 float out31 = tmp3.s1 + tmp3.s2;
592 float out32 = tmp3.s2 - tmp3.s1;
593 float out33 = tmp3.s1 - tmp3.s3;
594
595 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
596
597 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00;
598 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01;
599 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02;
600 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03;
601 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
602 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
603 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
604 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
605 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
606 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
607 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
608 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
609 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
610 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
611 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
612 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
613}
614
615/** This OpenCL kernel computes the input transform when the kernel size is 3x3, the output tile is 2x2 and the number of channels is multiple of 2
616 *
617 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
618 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
619 *
620 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
621 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
622 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
623 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
624 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
625 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
626 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
627 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
628 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
629 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
630 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
631 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
632 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
633 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
634 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
635 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
636 */
637__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
638 TENSOR3D_DECLARATION(src),
639 TENSOR3D_DECLARATION(dst))
640{
641 int x = get_global_id(0);
642 int y = get_global_id(1);
643 int z = get_global_id(2) * 2;
644
645 // Compute input address
646 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
647
648 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
649
650 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
651 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
652 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
653 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
654
655 src_addr += src_stride_z;
656 float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
657 float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
658 float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
659 float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
660
661 float4 tmp0 = in_row0 - in_row2;
662 float4 tmp1 = in_row1 + in_row2;
663 float4 tmp2 = in_row2 - in_row1;
664 float4 tmp3 = in_row1 - in_row3;
665
666 float4 tmp4 = in_row4 - in_row6;
667 float4 tmp5 = in_row5 + in_row6;
668 float4 tmp6 = in_row6 - in_row5;
669 float4 tmp7 = in_row5 - in_row7;
670
671 float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
672 float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
673 float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
674 float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
675
676 float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
677 float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
678 float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
679 float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
680
681 float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
682 float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
683 float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
684 float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
685
686 float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
687 float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
688 float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
689 float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
690
691 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
692
693 vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
694 vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
695 vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
696 vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
697 vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
698 vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
699 vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
700 vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
701 vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
702 vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
703 vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
704 vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
705 vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
706 vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
707 vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
708 vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
709}
Giorgio Arenafe5ef382018-04-17 10:14:10 +0100710
Gian Marco Iodicee52a3002018-04-11 15:59:10 +0100711/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data format is NCHW
712 *
713 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
714 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
715 *
716 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
717 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
718 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
719 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
720 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
721 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
722 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
723 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
724 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
725 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
726 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
727 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
728 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
729 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
730 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
731 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
732 */
733__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
734 TENSOR3D_DECLARATION(src),
735 TENSOR3D_DECLARATION(dst))
736{
737 int x = get_global_id(0);
738 int y = get_global_id(1);
739 int z = get_global_id(2);
740
741 // Compute input address
742 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
743
744 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
745
746 // Row4
747 float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
748 float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));
749
750 float k0 = d41.s0;
751 float k1 = d41.s0;
752 float k2 = d41.s0;
753 float k3 = d41.s0;
754 float k4 = d41.s0;
755 float k5 = 0.0f;
756
757 k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
758 k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
759 k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
760 k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
761 k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
762 k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;
763
764 // Row0
765 float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
766 float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
767
768 // Row2
769 float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
770 float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));
771
772 // Compute destination address
773 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y);
774
775 uint dst_plane_stride = dst_stride_z / sizeof(float);
776
777 float out0 = k0;
778 float out1 = k1;
779 float out2 = k2;
780 float out3 = k3;
781 float out4 = k4;
782 float out5 = k5;
783 float out6 = k0;
784 float out7 = k1;
785 float out8 = k2;
786 float out9 = k3;
787 float out10 = k4;
788 float out11 = k5;
789 float out12 = k0;
790 float out13 = k1;
791 float out14 = k2;
792 float out15 = k3;
793 float out16 = k4;
794 float out17 = k5;
795 float out18 = k0;
796 float out19 = k1;
797 float out20 = k2;
798 float out21 = k3;
799 float out22 = k4;
800 float out23 = k5;
801 float out24 = k0;
802 float out25 = k1;
803 float out26 = k2;
804 float out27 = k3;
805 float out28 = k4;
806 float out29 = k5;
807
808 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
809 out0 += 16.0f * d00.s0 - 20.0f * d00.s2 - 20.0f * d20.s0 + 25.0f * d20.s2 + 4.0f * d01.s0 - 5.0f * d21.s0;
810 out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
811 out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 - 20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
812 out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
813 out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 - 10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 + 4.0f * d01.s0 - 5.0f * d21.s0;
814 out5 += 16.0f * d00.s1 - 20.0f * d00.s3 - 20.0f * d20.s1 + 4.0f * d01.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
815
816 *(dst_addr) = out0;
817 dst_addr += dst_plane_stride;
818 *(dst_addr) = out1;
819 dst_addr += dst_plane_stride;
820 *(dst_addr) = out2;
821 dst_addr += dst_plane_stride;
822 *(dst_addr) = out3;
823 dst_addr += dst_plane_stride;
824 *(dst_addr) = out4;
825 dst_addr += dst_plane_stride;
826 *(dst_addr) = out5;
827 dst_addr += dst_plane_stride;
828
829 // Row1
830 float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
831 float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));
832
833 // Row3
834 float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
835 float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));
836
837 // Compute common parts for the channels between [6, 29]
838 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
839 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
840 float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
841 float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
842 float part2 = 16.0f * d20.s2 - 4.0f * d21.s0;
843 float part3 = 16.0f * d20.s1 - 4.0f * d20.s3;
844 float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
845 float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
846 float part6 = 4.0f * d20.s2 - 4.0f * d21.s0;
847 float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
848 float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
849 float part9 = 8.0f * d20.s1 - 8.0f * d20.s3;
850 float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
851 float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;
852
853 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
854 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
855 float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
856 float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
857 float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
858 float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
859 float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
860 float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
861 float part18 = part6 * 0.25f; // d20.s2 - d21.s0
862 float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
863 float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
864 float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3)
865 float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
866 float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;
867
868 out6 += part0 - part1;
869 out12 += part0 + part1;
870 out7 += part2 + part3 + part4 + part5;
871 out8 += part2 - part3 + part4 - part5;
872 out13 += part2 + part3 - part4 - part5;
873 out14 += part2 - part3 - part4 + part5;
874 out9 += part6 + part7 + part8 + part9;
875 out10 += part6 - part7 + part8 - part9;
876 out15 += part6 - part7 - part8 + part9;
877 out16 += part6 + part7 - part8 - part9;
878 out11 += part10 + part11;
879 out17 += part10 - part11;
880
881 out18 += part13 - part12;
882 out24 += part13 + part12;
883 out19 += part14 + part15 + part16 + part17;
884 out20 += part14 - part15 + part16 - part17;
885 out25 += part14 - part15 - part16 + part17;
886 out26 += part14 + part15 - part16 - part17;
887 out21 += part18 + part19 + part20 + part21;
888 out22 += part18 - part19 + part20 - part21;
889 out27 += part18 - part19 - part20 + part21;
890 out28 += part18 + part19 - part20 - part21;
891 out23 += part22 + part23;
892 out29 += part22 - part23;
893
894 *(dst_addr) = out6;
895 dst_addr += dst_plane_stride;
896 *(dst_addr) = out7;
897 dst_addr += dst_plane_stride;
898 *(dst_addr) = out8;
899 dst_addr += dst_plane_stride;
900 *(dst_addr) = out9;
901 dst_addr += dst_plane_stride;
902 *(dst_addr) = out10;
903 dst_addr += dst_plane_stride;
904 *(dst_addr) = out11;
905 dst_addr += dst_plane_stride;
906 *(dst_addr) = out12;
907 dst_addr += dst_plane_stride;
908 *(dst_addr) = out13;
909 dst_addr += dst_plane_stride;
910 *(dst_addr) = out14;
911 dst_addr += dst_plane_stride;
912 *(dst_addr) = out15;
913 dst_addr += dst_plane_stride;
914 *(dst_addr) = out16;
915 dst_addr += dst_plane_stride;
916 *(dst_addr) = out17;
917 dst_addr += dst_plane_stride;
918
919 *(dst_addr) = out18;
920 dst_addr += dst_plane_stride;
921 *(dst_addr) = out19;
922 dst_addr += dst_plane_stride;
923 *(dst_addr) = out20;
924 dst_addr += dst_plane_stride;
925 *(dst_addr) = out21;
926 dst_addr += dst_plane_stride;
927 *(dst_addr) = out22;
928 dst_addr += dst_plane_stride;
929 *(dst_addr) = out23;
930 dst_addr += dst_plane_stride;
931 *(dst_addr) = out24;
932 dst_addr += dst_plane_stride;
933 *(dst_addr) = out25;
934 dst_addr += dst_plane_stride;
935 *(dst_addr) = out26;
936 dst_addr += dst_plane_stride;
937 *(dst_addr) = out27;
938 dst_addr += dst_plane_stride;
939 *(dst_addr) = out28;
940 dst_addr += dst_plane_stride;
941 *(dst_addr) = out29;
942 dst_addr += dst_plane_stride;
943
944 // Row5
945 float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
946 float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));
947
948 // Channels [30, 35]
949 out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
950 out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
951 out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
952 out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
953 out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
954 out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;
955
956 *(dst_addr) = out0;
957 dst_addr += dst_plane_stride;
958 *(dst_addr) = out1;
959 dst_addr += dst_plane_stride;
960 *(dst_addr) = out2;
961 dst_addr += dst_plane_stride;
962 *(dst_addr) = out3;
963 dst_addr += dst_plane_stride;
964 *(dst_addr) = out4;
965 dst_addr += dst_plane_stride;
966 *(dst_addr) = out5;
967 dst_addr += dst_plane_stride;
968}
969
/** Applies the 8-point row transform used by the 4x4/5x5 Winograd input transform.
 *
 * Reads lanes s0..s7 of @p tmp and writes lanes s0..s7 of @p out. Lanes s0..s6 of @p comm_fact are used as
 * scratch storage and are overwritten; lane s7 of @p comm_fact is not touched.
 *
 * The macro expands to a single statement, so it must be used as a statement (followed by ';') and cannot be
 * used as an expression. Every argument is parenthesised, so any lvalue expression (e.g. a dereferenced
 * pointer) may be passed. Each argument is evaluated several times: do not pass expressions with side effects.
 *
 * @param[out]     out       Object exposing lanes s0..s7 that receives the transformed row
 * @param[in]      tmp       Object exposing lanes s0..s7 that holds the row to transform
 * @param[in, out] comm_fact Object exposing lanes s0..s6 used for the shared partial sums
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                     \
    do                                                                              \
    {                                                                               \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                    \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                    \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                           \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;         \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;            \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;         \
                                                                                    \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;       \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                 \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                 \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                 \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                 \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                 \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                 \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;       \
    } while(0)
989
990/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4
991 *
992 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
993 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
994 *
995 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
996 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
997 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
998 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
999 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1000 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1001 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1002 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1003 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1004 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1005 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1006 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1007 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1008 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1009 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1010 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1011 */
__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Global ids 0 and 1 select the tile, global id 2 selects the Z slice of the source
    const int tile_x = get_global_id(0);
    const int tile_y = get_global_id(1);
    const int slice  = get_global_id(2);

    // Address of the first element of the 8x8 patch: tiles are 4 elements apart along X and Y,
    // and the patch start is moved back by PAD_LEFT elements in X and PAD_TOP rows in Y
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + tile_x * 4 * src_stride_x + tile_y * 4 * src_stride_y + slice * src_stride_z - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);

    // Read the eight rows of the patch, eight floats each
    const float8 row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
    const float8 row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
    const float8 row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
    const float8 row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
    const float8 row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
    const float8 row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
    const float8 row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
    const float8 row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));

    // Intermediate rows 0..2: rows 1 and 2 are the sum and difference of an even-row and an odd-row term
    const float8 even_12 = row2 + row6 - 4.25f * row4;
    const float8 odd_12  = row1 + row5 - 4.25f * row3;

    const float8 mid0 = row0 - row6 + 5.25f * row4 - 5.25f * row2;
    const float8 mid1 = even_12 + odd_12;
    const float8 mid2 = even_12 - odd_12;

    // 2.5 * row3 is shared by the odd terms of intermediate rows 3..6.
    // The same variable is later handed to OUTPUT_ROW_4x4_5x5 as its third argument.
    float8 scratch = 2.5f * row3;

    // Intermediate rows 3 and 4
    const float8 even_34 = 0.25f * row2 - 1.25f * row4 + row6;
    const float8 odd_34  = 0.5f * row1 - scratch + 2.f * row5;

    const float8 mid3 = odd_34 + even_34;
    const float8 mid4 = even_34 - odd_34;

    // Intermediate rows 5, 6 and 7
    const float8 odd_56  = 2.f * row1 - scratch + 0.5f * row5;
    const float8 even_56 = 4.f * row2 - 5.f * row4 + row6;

    const float8 mid5 = odd_56 + even_56;
    const float8 mid6 = even_56 - odd_56;
    const float8 mid7 = row7 - row1 + 5.25f * row3 - 5.25f * row5;

    // Produce the eight output rows from the intermediate rows
    float8 out0, out1, out2, out3, out4, out5, out6, out7;

    OUTPUT_ROW_4x4_5x5(out0, mid0, scratch);
    OUTPUT_ROW_4x4_5x5(out1, mid1, scratch);
    OUTPUT_ROW_4x4_5x5(out2, mid2, scratch);
    OUTPUT_ROW_4x4_5x5(out3, mid3, scratch);
    OUTPUT_ROW_4x4_5x5(out4, mid4, scratch);
    OUTPUT_ROW_4x4_5x5(out5, mid5, scratch);
    OUTPUT_ROW_4x4_5x5(out6, mid6, scratch);
    OUTPUT_ROW_4x4_5x5(out7, mid7, scratch);

    // Destination of this tile: X is indexed by the source Z slice, Y by the linearised tile index
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + slice * dst_stride_x + (tile_x + tile_y * (int)NUM_TILES_X) * dst_stride_y;

    // Scatter the 64 values along Z: component c of output row r is written at offset (r * 8 + c) * dst_stride_z
#define STORE_ROW_ACROSS_Z(r, v)                                                 \
    ({                                                                           \
        *((__global float *)(dst_addr + ((r) * 8 + 0) * dst_stride_z)) = (v).s0; \
        *((__global float *)(dst_addr + ((r) * 8 + 1) * dst_stride_z)) = (v).s1; \
        *((__global float *)(dst_addr + ((r) * 8 + 2) * dst_stride_z)) = (v).s2; \
        *((__global float *)(dst_addr + ((r) * 8 + 3) * dst_stride_z)) = (v).s3; \
        *((__global float *)(dst_addr + ((r) * 8 + 4) * dst_stride_z)) = (v).s4; \
        *((__global float *)(dst_addr + ((r) * 8 + 5) * dst_stride_z)) = (v).s5; \
        *((__global float *)(dst_addr + ((r) * 8 + 6) * dst_stride_z)) = (v).s6; \
        *((__global float *)(dst_addr + ((r) * 8 + 7) * dst_stride_z)) = (v).s7; \
    })

    STORE_ROW_ACROSS_Z(0, out0);
    STORE_ROW_ACROSS_Z(1, out1);
    STORE_ROW_ACROSS_Z(2, out2);
    STORE_ROW_ACROSS_Z(3, out3);
    STORE_ROW_ACROSS_Z(4, out4);
    STORE_ROW_ACROSS_Z(5, out5);
    STORE_ROW_ACROSS_Z(6, out6);
    STORE_ROW_ACROSS_Z(7, out7);

#undef STORE_ROW_ACROSS_Z
}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001138#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +00001139
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001140#if defined(NUM_TILES_X)
/** This OpenCL kernel performs the Winograd output transform for a 2x2 output tile, a 3x3 filter and NCHW data
 *
 * Each work-item reads the 16 values of one 4x4 tile (one value per Z offset of the source tensor) and
 * writes one 2x2 tile of the destination tensor.
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note When HAS_BIAS is defined at compile time, the kernel takes an additional bias vector (declared through VECTOR_DECLARATION(bias)).
 *       The bias element selected by the global id along dimension 0 is added to every output value.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_2x2_3x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // One work-item produces one 2x2 output tile
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *tile_addr = tensor3D_offset(&src, 0, 0, 0);

    // Gather the 4x4 tile: element (r, c) is read at offset (r * 4 + c) * src_stride_z
    const float t00 = *((__global float *)(tile_addr + 0 * src_stride_z));
    const float t01 = *((__global float *)(tile_addr + 1 * src_stride_z));
    const float t02 = *((__global float *)(tile_addr + 2 * src_stride_z));
    const float t03 = *((__global float *)(tile_addr + 3 * src_stride_z));

    const float t10 = *((__global float *)(tile_addr + 4 * src_stride_z));
    const float t11 = *((__global float *)(tile_addr + 5 * src_stride_z));
    const float t12 = *((__global float *)(tile_addr + 6 * src_stride_z));
    const float t13 = *((__global float *)(tile_addr + 7 * src_stride_z));

    const float t20 = *((__global float *)(tile_addr + 8 * src_stride_z));
    const float t21 = *((__global float *)(tile_addr + 9 * src_stride_z));
    const float t22 = *((__global float *)(tile_addr + 10 * src_stride_z));
    const float t23 = *((__global float *)(tile_addr + 11 * src_stride_z));

    const float t30 = *((__global float *)(tile_addr + 12 * src_stride_z));
    const float t31 = *((__global float *)(tile_addr + 13 * src_stride_z));
    const float t32 = *((__global float *)(tile_addr + 14 * src_stride_z));
    const float t33 = *((__global float *)(tile_addr + 15 * src_stride_z));

    // Column partial sums shared by the two outputs of each row
    const float col1_top = t01 + t11 + t21;
    const float col2_top = t02 + t12 + t22;
    const float col1_bot = t11 - t21 - t31;
    const float col2_bot = t12 - t22 - t32;

    // row 0: (t00 + t10 + t20 + col1_top + col2_top, col1_top - col2_top - (t03 + t13 + t23))
    // row 1: (t10 - t20 - t30 + col1_bot + col2_bot, col1_bot - col2_bot - (t13 - t23 - t33))
    float2 out_row0 = (float2)(t10 + (t00 + t20 + col1_top + col2_top),
                               -t13 + (col1_top - col2_top - (t03 + t23)));
    float2 out_row1 = (float2)(t10 + (-t20 - t30 + col1_bot + col2_bot),
                               -t13 + (col1_bot - col2_bot + t23 + t33));

    // Global id 1 is the linear tile index, global id 0 is the output Z index
    const int tile_idx = get_global_id(1);
    const int x_out    = (tile_idx % NUM_TILES_X) * 2;
    const int y_out    = (tile_idx / NUM_TILES_X) * 2;
    const int z_out    = get_global_id(0);

#if defined(HAS_BIAS)
    // Add the bias element selected by z_out to all four outputs
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    const float bias_val = *((__global float *)(vector_offset(&bias, z_out)));

    out_row0 += (float2)bias_val;
    out_row1 += (float2)bias_val;
#endif // defined(HAS_BIAS)

    // Address of the top-left element of the output tile
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;

    // Write the two rows of the 2x2 tile
    vstore2(out_row0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore2(out_row1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
}
Giorgio Arenadd038702018-04-16 11:20:11 +01001242
/** This OpenCL kernel performs the Winograd output transform for a 4x4 output tile, a 3x3 filter and NCHW data
 *
 * Each work-item reads the 36 values of one 6x6 tile (one value per Z offset of the source tensor) and
 * writes one 4x4 tile of the destination tensor.
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note When HAS_BIAS is defined at compile time, the kernel takes an additional bias vector (declared through VECTOR_DECLARATION(bias)).
 *       The bias element selected by the global id along dimension 0 is added to every output value.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_output_transform_4x4_3x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // One work-item produces one 4x4 output tile
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *tile_addr = tensor3D_offset(&src, 0, 0, 0);

    // Gather the 6x6 tile: element (r, c) is read at offset (r * 6 + c) * src_stride_z
    const float t00 = *((__global float *)(tile_addr + 0 * src_stride_z));
    const float t01 = *((__global float *)(tile_addr + 1 * src_stride_z));
    const float t02 = *((__global float *)(tile_addr + 2 * src_stride_z));
    const float t03 = *((__global float *)(tile_addr + 3 * src_stride_z));
    const float t04 = *((__global float *)(tile_addr + 4 * src_stride_z));
    const float t05 = *((__global float *)(tile_addr + 5 * src_stride_z));

    const float t10 = *((__global float *)(tile_addr + 6 * src_stride_z));
    const float t11 = *((__global float *)(tile_addr + 7 * src_stride_z));
    const float t12 = *((__global float *)(tile_addr + 8 * src_stride_z));
    const float t13 = *((__global float *)(tile_addr + 9 * src_stride_z));
    const float t14 = *((__global float *)(tile_addr + 10 * src_stride_z));
    const float t15 = *((__global float *)(tile_addr + 11 * src_stride_z));

    const float t20 = *((__global float *)(tile_addr + 12 * src_stride_z));
    const float t21 = *((__global float *)(tile_addr + 13 * src_stride_z));
    const float t22 = *((__global float *)(tile_addr + 14 * src_stride_z));
    const float t23 = *((__global float *)(tile_addr + 15 * src_stride_z));
    const float t24 = *((__global float *)(tile_addr + 16 * src_stride_z));
    const float t25 = *((__global float *)(tile_addr + 17 * src_stride_z));

    const float t30 = *((__global float *)(tile_addr + 18 * src_stride_z));
    const float t31 = *((__global float *)(tile_addr + 19 * src_stride_z));
    const float t32 = *((__global float *)(tile_addr + 20 * src_stride_z));
    const float t33 = *((__global float *)(tile_addr + 21 * src_stride_z));
    const float t34 = *((__global float *)(tile_addr + 22 * src_stride_z));
    const float t35 = *((__global float *)(tile_addr + 23 * src_stride_z));

    const float t40 = *((__global float *)(tile_addr + 24 * src_stride_z));
    const float t41 = *((__global float *)(tile_addr + 25 * src_stride_z));
    const float t42 = *((__global float *)(tile_addr + 26 * src_stride_z));
    const float t43 = *((__global float *)(tile_addr + 27 * src_stride_z));
    const float t44 = *((__global float *)(tile_addr + 28 * src_stride_z));
    const float t45 = *((__global float *)(tile_addr + 29 * src_stride_z));

    const float t50 = *((__global float *)(tile_addr + 30 * src_stride_z));
    const float t51 = *((__global float *)(tile_addr + 31 * src_stride_z));
    const float t52 = *((__global float *)(tile_addr + 32 * src_stride_z));
    const float t53 = *((__global float *)(tile_addr + 33 * src_stride_z));
    const float t54 = *((__global float *)(tile_addr + 34 * src_stride_z));
    const float t55 = *((__global float *)(tile_addr + 35 * src_stride_z));

    // For every output row:
    //   baseN = column-1 term, common to the four outputs of the row
    //   sumN  = combined term of columns 3 and 4 (used by outputs 0 and 2)
    //   difN  = twice the difference of columns 3 and 4 (used by outputs 1 and 3)

    // Output row 0
    const float base0 = t01 + t21 + t41 + t11 + t31;
    const float sum0  = t03 + t04 + t13 + t14 + t23 + t24 + t33 + t34 + t43 + t44;
    const float dif0  = 2.0f * t03 - 2.0f * t04 + 2.0f * t13 - 2.0f * t14 + 2.0f * t23 - 2.0f * t24 + 2.0f * t33 - 2.0f * t34 + 2.0f * t43 - 2.0f * t44;

    const float out00 = base0 + (sum0 + t00 + t02 + t10 + t12 + t20 + t22 + t30 + t32 + t40 + t42);
    const float out01 = base0 + (dif0 - t02 - t12 - t22 - t32 - t42);
    const float out02 = base0 + (4.0f * sum0 + t02 + t12 + t22 + t32 + t42);
    const float out03 = base0 + (4.0f * dif0 - t02 - t12 - t22 - t32 - t42 + t05 + t15 + t25 + t35 + t45);

    // Output row 1
    const float base1 = t11 - t21 + 2.0f * t31 - 2.0f * t41;
    const float sum1  = t13 + t14 - t23 - t24 + 2.0f * t33 + 2.0f * t34 - 2.0f * t43 - 2.0f * t44;
    const float dif1  = 2.0f * t13 - 2.0f * t14 - 2.0f * t23 + 2.0f * t24 + 4.0f * t33 - 4.0f * t34 - 4.0f * t43 + 4.0f * t44;

    const float out10 = base1 + (sum1 + t10 + t12 - t20 - t22 + 2.0f * t30 + 2.0f * t32 - 2.0f * t40 - 2.0f * t42);
    const float out11 = base1 + (dif1 - t12 + t22 - 2.0f * t32 + 2.0f * t42);
    const float out12 = base1 + (4.0f * sum1 + t12 - t22 + 2.0f * t32 - 2.0f * t42);
    const float out13 = base1 + (4.0f * dif1 - t12 + t15 + t22 - t25 - 2.0f * t32 + 2.0f * t35 + 2.0f * t42 - 2.0f * t45);

    // Output row 2
    const float base2 = t11 + t21 + 4.0f * t31 + 4.0f * t41;
    const float sum2  = t13 + t14 + t23 + t24 + 4.0f * t33 + 4.0f * t34 + 4.0f * t43 + 4.0f * t44;
    const float dif2  = 2.0f * t13 - 2.0f * t14 + 2.0f * t23 - 2.0f * t24 + 8.0f * t33 - 8.0f * t34 + 8.0f * t43 - 8.0f * t44;

    const float out20 = base2 + (sum2 + t10 + t12 + t20 + t22 + 4.0f * t30 + 4.0f * t32 + 4.0f * t40 + 4.0f * t42);
    const float out21 = base2 + (dif2 - t12 - t22 - 4.0f * t32 - 4.0f * t42);
    const float out22 = base2 + (4.0f * sum2 + t12 + t22 + 4.0f * t32 + 4.0f * t42);
    const float out23 = base2 + (4.0f * dif2 - t12 + t15 - t22 + t25 - 4.0f * t32 + 4.0f * t35 - 4.0f * t42 + 4.0f * t45);

    // Output row 3
    const float base3 = t11 - t21 + 8.0f * t31 - 8.0f * t41 + t51;
    const float sum3  = t13 + t14 - t23 - t24 + 8.0f * t33 + 8.0f * t34 - 8.0f * t43 - 8.0f * t44 + t53 + t54;
    const float dif3  = 2.0f * t13 - 2.0f * t14 - 2.0f * t23 + 2.0f * t24 + 16.0f * t33 - 16.0f * t34 - 16.0f * t43 + 16.0f * t44 + 2.0f * t53 - 2.0f * t54;

    const float out30 = base3 + (sum3 + t10 + t12 - t20 - t22 + 8.0f * t30 + 8.0f * t32 - 8.0f * t40 - 8.0f * t42 + t50 + t52);
    const float out31 = base3 + (dif3 - t12 + t22 - 8.0f * t32 + 8.0f * t42 - t52);
    const float out32 = base3 + (4.0f * sum3 + t12 - t22 + 8.0f * t32 - 8.0f * t42 + t52);
    const float out33 = base3 + (4.0f * dif3 - t12 + t15 + t22 - t25 - 8.0f * t32 + 8.0f * t35 + 8.0f * t42 - 8.0f * t45 - t52 + t55);

    // Pack the results row by row so that bias addition and stores operate on whole rows
    float4 out_row0 = (float4)(out00, out01, out02, out03);
    float4 out_row1 = (float4)(out10, out11, out12, out13);
    float4 out_row2 = (float4)(out20, out21, out22, out23);
    float4 out_row3 = (float4)(out30, out31, out32, out33);

    // Global id 1 is the linear tile index, global id 0 is the output Z index
    const int tile_idx = get_global_id(1);
    const int x_out    = (tile_idx % NUM_TILES_X) * 4;
    const int y_out    = (tile_idx / NUM_TILES_X) * 4;
    const int z_out    = get_global_id(0);

#if defined(HAS_BIAS)
    // Add the bias element selected by z_out to all sixteen outputs
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    const float bias_val = *((__global float *)(vector_offset(&bias, z_out)));

    out_row0 += (float4)bias_val;
    out_row1 += (float4)bias_val;
    out_row2 += (float4)bias_val;
    out_row3 += (float4)bias_val;
#endif // defined(HAS_BIAS)

    // Address of the top-left element of the output tile
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;

    // Write the four rows of the 4x4 tile
    vstore4(out_row0, 0, (__global float *)(dst_addr + 0 * dst_stride_y));
    vstore4(out_row1, 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4(out_row2, 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4(out_row3, 0, (__global float *)(dst_addr + 3 * dst_stride_y));
}
1419
/** Combine eight scalars into the four components of one intermediate column
 *
 * The results are:
 *   col.s0 = d0 + (d1 + d2) +       (d3 + d4) + 8 * (d5 + d6)
 *   col.s1 =      (d1 - d2) + 2.f * (d3 - d4) + 4 * (d5 - d6)
 *   col.s2 =      (d1 + d2) + 4.f * (d3 + d4) + 2 * (d5 + d6)
 *   col.s3 =      (d1 - d2) + 8.f * (d3 - d4) +     (d5 - d6) + d7
 *
 * Every argument is parenthesized in the expansion, so compound expressions such as `a + b`
 * can be passed safely. Arguments d1..d6 are evaluated twice: do not pass expressions with side effects.
 *
 * @param[out]     col       Lvalue with components s0, s1, s2, s3 receiving the results
 * @param[in]      d0        First input value (only used by col.s0)
 * @param[in]      d1        Second input value
 * @param[in]      d2        Third input value
 * @param[in]      d3        Fourth input value
 * @param[in]      d4        Fifth input value
 * @param[in]      d5        Sixth input value
 * @param[in]      d6        Seventh input value
 * @param[in]      d7        Eighth input value (only used by col.s3)
 * @param[in, out] comm_fact Scratch lvalue; its components s0, s1 and s2 are overwritten
 */
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)                                      \
    ({                                                                                                       \
        (comm_fact).s0 = (d1) + (d2);                                                                        \
        (comm_fact).s1 = (d3) + (d4);                                                                        \
        (comm_fact).s2 = (d5) + (d6);                                                                        \
                                                                                                             \
        (col).s0 = (comm_fact).s0 + (comm_fact).s1 + 8.f * (comm_fact).s2 + (d0);                            \
        (col).s2 = (comm_fact).s0 + 4.f * (comm_fact).s1 + 2.f * (comm_fact).s2;                             \
                                                                                                             \
        (comm_fact).s0 = (d1) - (d2);                                                                        \
        (comm_fact).s1 = (d3) - (d4);                                                                        \
        (comm_fact).s2 = (d5) - (d6);                                                                        \
                                                                                                             \
        (col).s1 = (comm_fact).s0 + 2.f * (comm_fact).s1 + 4.f * (comm_fact).s2;                             \
        (col).s3 = (comm_fact).s0 + 8.f * (comm_fact).s1 + (comm_fact).s2 + (d7);                            \
    })
1436
1437/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data format is NCHW
1438 *
1439 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
1440 *
1441 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
1442 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1443 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1444 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1445 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1446 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1447 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1448 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1449 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1450 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1451 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1452 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1453 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes)
1456 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1457 */
1458__kernel void winograd_output_transform_4x4_5x5_nchw(
1459 TENSOR3D_DECLARATION(src),
1460 TENSOR3D_DECLARATION(dst)
1461#if defined(HAS_BIAS)
1462 ,
1463 VECTOR_DECLARATION(bias)
1464#endif // defined(HAS_BIAS)
1465)
1466{
1467 // Each thread stores a 4x4 tile
1468 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
1469
1470 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
1471
1472 // Load the values across the 64 channels to compose the 8x8 input tile
1473 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
1474 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
1475 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
1476 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
1477 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
1478 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
1479 float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
1480 float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
1481
1482 float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
1483 float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
1484 float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
1485 float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
1486 float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
1487 float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
1488 float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
1489 float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
1490
1491 float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
1492 float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
1493 float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
1494 float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
1495 float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
1496 float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
1497 float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
1498 float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
1499
1500 float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
1501 float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
1502 float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
1503 float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
1504 float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
1505 float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
1506 float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
1507 float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
1508
1509 float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
1510 float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
1511 float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
1512 float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
1513 float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
1514 float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
1515 float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
1516 float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
1517
1518 float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
1519 float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
1520 float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
1521 float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
1522 float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
1523 float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
1524 float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
1525 float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
1526
1527 float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
1528 float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
1529 float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
1530 float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
1531 float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
1532 float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
1533 float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
1534 float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
1535
1536 float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
1537 float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
1538 float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
1539 float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
1540 float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
1541 float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
1542 float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
1543 float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
1544
1545 // Compute the 8x4 intermediate tensor
1546 float4 comm_fact0, comm_fact1, comm_fact2;
1547 float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
1548
1549 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1550 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1551 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1552 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1553 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1554 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1555 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1556 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
1557
1558 // Compute the 4x4 output tile
1559 comm_fact0 = tmp_col1 + tmp_col2;
1560 comm_fact1 = tmp_col3 + tmp_col4;
1561 comm_fact2 = tmp_col5 + tmp_col6;
1562
1563 float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
1564 float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
1565
1566 comm_fact0 = tmp_col1 - tmp_col2;
1567 comm_fact1 = tmp_col3 - tmp_col4;
1568 comm_fact2 = tmp_col5 - tmp_col6;
1569
1570 float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
1571 float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
1572
1573 int y_in = get_global_id(1);
1574 int x_out = (y_in % NUM_TILES_X) * 4;
1575 int y_out = (y_in / NUM_TILES_X) * 4;
1576 int z_out = get_global_id(0);
1577
1578#if defined(HAS_BIAS)
1579 // Add bias
1580 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
1581
1582 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
1583
1584 out_col0 += (float4)b;
1585 out_col1 += (float4)b;
1586 out_col2 += (float4)b;
1587 out_col3 += (float4)b;
1588#endif // defined(HAS_BIAS)
1589
1590 // Get output address
1591 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
1592
1593 // Store the 4x4 output tile
1594 *(__global float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0;
1595 *(__global float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0;
1596 *(__global float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0;
1597 *(__global float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0;
1598 *(__global float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1;
1599 *(__global float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1;
1600 *(__global float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1;
1601 *(__global float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1;
1602 *(__global float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2;
1603 *(__global float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2;
1604 *(__global float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2;
1605 *(__global float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2;
1606 *(__global float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3;
1607 *(__global float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3;
1608 *(__global float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3;
1609 *(__global float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3;
1610}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001611#endif // defined(NUM_TILES_X)