blob: cda23b0155d8c971e9b99b18b42bf97b8bc26ef9 [file] [log] [blame]
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Gian Marco Iodiced2fab732018-03-02 11:18:12 +000026#if defined(NUM_CHANNELS)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +000027
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 2x2
 *
 * Each work-item reads one 3x3 filter slice (one filter, one input channel) and computes the 4x4 tile
 * G * w * G^T, where G = [1, 0, 0; 1/2, 1/2, 1/2; 1/2, -1/2, 1/2; 0, 0, 1].
 * The 16 results are written to the destination at X = filter index, Y = channel index,
 * one value per Z plane (Z plane = 4 * tile_row + tile_column).
 *
 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the three rows of the 3x3 filter slice (rows are src_stride_y bytes apart)
    const float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    const float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    const float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));

    // Transform the 3x3 tile into a 4x4 tile: out = G * w * G^T (expanded per element)
    float4 out0 = 0.0f;
    float4 out1 = 0.0f;
    float4 out2 = 0.0f;
    float4 out3 = 0.0f;

    // Row 0
    out0.s0 = (w0.s0);
    out0.s1 = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2 = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3 = (w0.s2);

    // Row 1
    out1.s0 = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3 = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2
    out2.s0 = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1 = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2 = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3 = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3
    out3.s0 = (w2.s0);
    out3.s1 = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2 = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3 = (w2.s2);

    // The global Z id enumerates (filter, channel) pairs: z = filter * NUM_CHANNELS + channel
    const int z           = get_global_id(2);
    const int filter_idx  = z / NUM_CHANNELS;
    const int channel_idx = z % NUM_CHANNELS;

    // In the destination, X selects the filter and Y selects the channel
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + filter_idx * dst_stride_x + channel_idx * dst_stride_y;

    // Store the 16 transformed values, one per Z plane (plane = 4 * row + column of the 4x4 tile)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
}
Giorgio Arena2d9de0a2018-03-15 17:58:20 +0000119
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data format is NCHW and the output tile is 4x4
 *
 * Each work-item reads one 3x3 filter slice (one filter, one input channel) and computes the 6x6 tile
 * G * w * G^T, where
 * G = [1/4, 0, 0; -1/6, -1/6, -1/6; -1/6, 1/6, -1/6; 1/24, 1/12, 1/6; 1/24, -1/12, 1/6; 0, 0, 1].
 * The 36 results are written to the destination at X = filter index, Y = channel index,
 * one value per Z plane (Z plane = 6 * tile_row + tile_column).
 *
 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the three rows of the 3x3 filter slice (rows are src_stride_y bytes apart)
    const float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    const float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    const float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));

    // Transform the 3x3 tile into a 6x6 tile: out = G * w * G^T (expanded per element).
    // OpenCL C has no 6-wide vector type, so each row uses a float8 and only lanes s0..s5 are stored.
    float8 out0 = 0.0f;
    float8 out1 = 0.0f;
    float8 out2 = 0.0f;
    float8 out3 = 0.0f;
    float8 out4 = 0.0f;
    float8 out5 = 0.0f;

    // Row 0
    out0.s0 = (w0.s0) / 16.f;
    out0.s1 = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2 = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3 = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4 = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5 = (w0.s2) / 4.f;

    // Row 1
    out1.s0 = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1 = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2 = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4 = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5 = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    out2.s0 = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1 = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2 = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4 = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5 = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    out3.s0 = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2 = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4 = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5 = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    out4.s0 = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2 = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4 = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5 = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    out5.s0 = (w2.s0) / 4.f;
    out5.s1 = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2 = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3 = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4 = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5 = (w2.s2);

    // The global Z id enumerates (filter, channel) pairs: z = filter * NUM_CHANNELS + channel
    const int z           = get_global_id(2);
    const int filter_idx  = z / NUM_CHANNELS;
    const int channel_idx = z % NUM_CHANNELS;

    // In the destination, X selects the filter and Y selects the channel
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + filter_idx * dst_stride_x + channel_idx * dst_stride_y;

    // Store the 36 transformed values, one per Z plane (plane = 6 * row + column of the 6x6 tile)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
}
Giorgio Arena9373c8b2018-04-11 19:07:17 +0100257
/** This OpenCL kernel performs Winograd filter transform 5x5 when the data format is NCHW and the output tile is 4x4
 *
 * @note The number of channels must be passed at compile time using -DNUM_CHANNELS: e.g. -DNUM_CHANNELS=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
281__kernel void winograd_filter_transform_4x4_5x5_nchw(
282 TENSOR4D_DECLARATION(src),
283 TENSOR3D_DECLARATION(dst))
284{
285 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, NUM_CHANNELS);
286
287 const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);
288
289 // Load the values from the input tensor
290 const char stride_x = 4 * sizeof(float); // Used for accessing the last value in each row
291 const uchar8 stride_y = (uchar8)(0, 1, 2, 3, 4, 0, 0, 0) * (uchar8)src_stride_y;
292
293 float4 w00 = vload4(0, (__global float *)(src_addr + stride_y.s0));
294 float w01 = *((__global float *)(src_addr + stride_y.s0 + stride_x));
295 float4 w10 = vload4(0, (__global float *)(src_addr + stride_y.s1));
296 float w11 = *((__global float *)(src_addr + stride_y.s1 + stride_x));
297 float4 w20 = vload4(0, (__global float *)(src_addr + stride_y.s2));
298 float w21 = *((__global float *)(src_addr + stride_y.s2 + stride_x));
299 float4 w30 = vload4(0, (__global float *)(src_addr + stride_y.s3));
300 float w31 = *((__global float *)(src_addr + stride_y.s3 + stride_x));
301 float4 w40 = vload4(0, (__global float *)(src_addr + stride_y.s4));
302 float w41 = *((__global float *)(src_addr + stride_y.s4 + stride_x));
303
304 // Transform the 3x3 tile in a 8x8 tile
305 float8 out0 = 0.0f;
306 float8 out1 = 0.0f;
307 float8 out2 = 0.0f;
308 float8 out3 = 0.0f;
309 float8 out4 = 0.0f;
310 float8 out5 = 0.0f;
311 float8 out6 = 0.0f;
312 float8 out7 = 0.0f;
313
314 // Row 0
315 out0.s0 = w00.s0;
316 out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f;
317 out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f;
318 out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f;
319 out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f;
320 out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f;
321 out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f;
322 out0.s7 = w01;
323
324 // Row 1
325 out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f;
326 out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) +
327 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
328 out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) -
329 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f;
330 out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f *
331 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
332 out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f *
333 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f;
334 out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f *
335 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
336 out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f *
337 (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f;
338 out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f;
339
340 // Row 2
341 out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f;
342 out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) +
343 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
344 out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) -
345 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f;
346 out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f *
347 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
348 out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f *
349 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f;
350 out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f *
351 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
352 out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f *
353 (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f;
354 out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f;
355
356 // Row 3
357 out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
358 out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
359 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
360 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
361 out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) +
362 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
363 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f;
364 out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
365 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
366 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
367 out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
368 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
369 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f;
370 out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
371 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
372 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
373 out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
374 (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) +
375 (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f;
376 out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f;
377
378 // Row 4
379 out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f;
380 out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
381 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
382 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
383 out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) +
384 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
385 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f;
386 out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
387 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
388 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
389 out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
390 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f *
391 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f;
392 out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
393 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
394 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
395 out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f *
396 (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) +
397 (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f;
398 out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f;
399
400 // Row 5
401 out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f;
402 out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
403 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
404 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
405 out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) +
406 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
407 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f;
408 out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
409 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
410 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
411 out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
412 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f *
413 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f;
414 out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
415 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
416 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
417 out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s0 + w40.s1) + 4.f *
418 (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) +
419 (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f;
420 out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f;
421
422 // Row 6
423 out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f;
424 out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
425 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
426 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
427 out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) +
428 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
429 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f;
430 out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
431 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
432 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
433 out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
434 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f *
435 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f;
436 out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
437 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
438 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
439 out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s0 + w40.s1) + 4.f *
440 (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) +
441 (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f;
442 out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f;
443
444 // Row 7
445 out7.s0 = w40.s0;
446 out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f;
447 out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f;
448 out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f;
449 out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f;
450 out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f;
451 out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f;
452 out7.s7 = w41;
453
454 int z = get_global_id(2);
455 int x0 = z / NUM_CHANNELS; // idx filter
456 int y0 = z % NUM_CHANNELS; // idx channel
457
458 // Get output address
459 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;
460
461 // Store the 64 values across the 64 channels
462 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
463 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
464 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
465 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
466 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
467 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
468 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
469 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
470 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
471 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
472 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
473 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
474 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
475 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
476 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
477 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
478 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
479 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
480 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
481 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
482 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
483 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
484 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
485 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
486 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
487 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
488 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
489 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
490 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
491 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
492 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
493 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
494 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
495 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
496 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
497 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
498 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
499 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
500 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
501 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
502 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
503 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
504 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
505 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
506 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
507 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
508 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
509 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
510 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
511 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
512 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
513 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
514 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
515 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
516 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
517 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
518 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
519 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
520 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
521 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
522 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
523 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
524 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
525 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
526}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000527#endif // defined(NUM_CHANNELS)
528
529#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +0000530/** This OpenCL kernel computes the input transform when the kernel size is 3x3 and the output tile is 2x2
531 *
532 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
533 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
534 *
535 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
536 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
537 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
538 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
539 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
540 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
541 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
542 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
543 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
544 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
545 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
546 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
547 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
548 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
549 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
550 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
551 */
552__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
553 TENSOR3D_DECLARATION(src),
554 TENSOR3D_DECLARATION(dst))
555{
556 int x = get_global_id(0);
557 int y = get_global_id(1);
558 int z = get_global_id(2);
559
560 // Compute input address
561 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
562
563 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
564
565 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
566 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
567 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
568 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
569
570 float4 tmp0 = in_row0 - in_row2;
571 float4 tmp1 = in_row1 + in_row2;
572 float4 tmp2 = in_row2 - in_row1;
573 float4 tmp3 = in_row1 - in_row3;
574
575 float out00 = tmp0.s0 - tmp0.s2;
576 float out01 = tmp0.s1 + tmp0.s2;
577 float out02 = tmp0.s2 - tmp0.s1;
578 float out03 = tmp0.s1 - tmp0.s3;
579
580 float out10 = tmp1.s0 - tmp1.s2;
581 float out11 = tmp1.s1 + tmp1.s2;
582 float out12 = tmp1.s2 - tmp1.s1;
583 float out13 = tmp1.s1 - tmp1.s3;
584
585 float out20 = tmp2.s0 - tmp2.s2;
586 float out21 = tmp2.s1 + tmp2.s2;
587 float out22 = tmp2.s2 - tmp2.s1;
588 float out23 = tmp2.s1 - tmp2.s3;
589
590 float out30 = tmp3.s0 - tmp3.s2;
591 float out31 = tmp3.s1 + tmp3.s2;
592 float out32 = tmp3.s2 - tmp3.s1;
593 float out33 = tmp3.s1 - tmp3.s3;
594
595 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
596
597 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00;
598 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01;
599 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02;
600 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03;
601 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
602 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
603 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
604 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
605 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
606 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
607 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
608 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
609 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
610 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
611 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
612 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
613}
614
615/** This OpenCL kernel computes the input transform when the kernel size is 3x3, the output tile is 2x2 and the number of channels is multiple of 2
616 *
617 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
618 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
619 *
620 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
621 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
622 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
623 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
624 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
625 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
626 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
627 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
628 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
629 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
630 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
631 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
632 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
633 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
634 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
635 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
636 */
637__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
638 TENSOR3D_DECLARATION(src),
639 TENSOR3D_DECLARATION(dst))
640{
641 int x = get_global_id(0);
642 int y = get_global_id(1);
643 int z = get_global_id(2) * 2;
644
645 // Compute input address
646 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 2 * src_stride_x + y * 2 * src_stride_y + z * src_stride_z;
647
648 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
649
650 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
651 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
652 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
653 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
654
655 src_addr += src_stride_z;
656 float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
657 float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
658 float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
659 float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
660
661 float4 tmp0 = in_row0 - in_row2;
662 float4 tmp1 = in_row1 + in_row2;
663 float4 tmp2 = in_row2 - in_row1;
664 float4 tmp3 = in_row1 - in_row3;
665
666 float4 tmp4 = in_row4 - in_row6;
667 float4 tmp5 = in_row5 + in_row6;
668 float4 tmp6 = in_row6 - in_row5;
669 float4 tmp7 = in_row5 - in_row7;
670
671 float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
672 float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
673 float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
674 float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);
675
676 float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
677 float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
678 float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
679 float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);
680
681 float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
682 float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
683 float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
684 float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);
685
686 float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
687 float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
688 float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
689 float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
690
691 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
692
693 vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
694 vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
695 vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
696 vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));
697 vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
698 vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
699 vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
700 vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
701 vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
702 vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
703 vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
704 vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
705 vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
706 vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
707 vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
708 vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
709}
Giorgio Arenafe5ef382018-04-17 10:14:10 +0100710
/* Apply the 8-point Winograd input transform (4x4 output tile, 5x5 filter) along one row.
 *
 * out       - 8-component vector receiving the transformed row (s0..s7 written)
 * tmp       - 8-component vector holding one row of B^T * d (read only)
 * comm_fact - scratch vector for shared sub-expressions; components s0..s6 are overwritten
 *
 * The row coefficients are the same as the B^T rows used for the column pass in
 * winograd_input_transform_4x4_5x5_stepz1_nchw.
 *
 * Written as do { } while (0) so it expands to exactly one statement without relying on the
 * GNU statement-expression extension. Arguments are evaluated several times: pass plain lvalues.
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                     \
    do                                                                              \
    {                                                                               \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                    \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                    \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                           \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;         \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;            \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;         \
                                                                                    \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;       \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                 \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                 \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                 \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                 \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                 \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                 \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;       \
    } while(0)
730
731/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4
732 *
733 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
734 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
735 *
736 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
737 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
738 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
739 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
740 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
741 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
742 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
743 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
744 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
745 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
746 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
747 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
748 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
749 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
750 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
751 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
752 */
753__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
754 TENSOR3D_DECLARATION(src),
755 TENSOR3D_DECLARATION(dst))
756{
757 int x = get_global_id(0);
758 int y = get_global_id(1);
759 int z = get_global_id(2);
760
761 // Compute input address
762 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
763
764 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
765
766 // Load 8x8 input tile
767 const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
768 const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
769 const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
770 const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
771 const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
772 const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
773 const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
774 const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
775
776 // Calculate common factors for intermediate tensor
777 float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
778 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
779 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
780
781 // Calculate intermediate tensor and reuse common factor vectors
782 const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
783 const float8 tmp1 = comm_fact0 + comm_fact1;
784 const float8 tmp2 = comm_fact0 - comm_fact1;
785
786 comm_fact0 = 2.5f * in_row3;
787 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
788
789 const float8 tmp3 = comm_fact1 + comm_fact2;
790 const float8 tmp4 = comm_fact2 - comm_fact1;
791
792 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
793 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
794
795 const float8 tmp5 = comm_fact1 + comm_fact2;
796 const float8 tmp6 = comm_fact2 - comm_fact1;
797 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
798
799 // Calculate output rows (reuse comm_fact0 vector)
800 float8 out0, out1, out2, out3, out4, out5, out6, out7;
801
802 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
803 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
804 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
805 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
806 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
807 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
808 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
809 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
810
811 // Store values across the 64 channels
812 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
813
814 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
815 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
816 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
817 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
818 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
819 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
820 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
821 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
822 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
823 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
824 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
825 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
826 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
827 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
828 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
829 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
830 *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
831 *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
832 *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
833 *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
834 *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
835 *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
836 *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
837 *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
838 *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
839 *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
840 *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
841 *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
842 *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
843 *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
844 *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
845 *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
846 *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
847 *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
848 *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
849 *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
850 *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
851 *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
852 *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
853 *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
854 *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
855 *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
856 *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
857 *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
858 *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
859 *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
860 *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
861 *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
862 *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
863 *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
864 *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
865 *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
866 *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
867 *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
868 *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
869 *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
870 *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
871 *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
872 *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
873 *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
874 *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
875 *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
876 *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
877 *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
878}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000879#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000880
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000881#if defined(NUM_TILES_X)
882/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2, the filter size 3x3 and the data format is NCHW
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000883 *
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000884 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000885 *
886 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
887 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
888 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
889 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
890 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
891 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
892 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000893 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
894 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
895 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
896 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
897 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
898 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
899 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
900 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
901 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
902 */
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000903__kernel void winograd_output_transform_2x2_3x3_nchw(
904 TENSOR3D_DECLARATION(src),
905 TENSOR3D_DECLARATION(dst)
906#if defined(HAS_BIAS)
907 ,
908 VECTOR_DECLARATION(bias)
909#endif // defined(HAS_BIAS)
910)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000911{
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000912 // Each thread stores a 2x2 tile
913 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000914
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000915 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000916
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000917 // Load the values across the 16 channels to compose the 4x4 tile
918 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
919 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
920 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
921 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000922
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000923 float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
924 float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
925 float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
926 float d13 = *((__global float *)(src_addr + 7 * src_stride_z));
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000927
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000928 float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
929 float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
930 float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
931 float d23 = *((__global float *)(src_addr + 11 * src_stride_z));
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000932
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000933 float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
934 float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
935 float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
936 float d33 = *((__global float *)(src_addr + 15 * src_stride_z));
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000937
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000938 // Compute the 2x2 output tile
939 float k0 = d01 + d11 + d21;
940 float k1 = d02 + d12 + d22;
941 float k2 = d11 - d21 - d31;
942 float k3 = d12 - d22 - d32;
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000943
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000944 // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
945 // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
946 // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
947 // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000948
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000949 float out00 = d10;
950 float out01 = -d13;
951 float out10 = d10;
952 float out11 = -d13;
953
954 out00 += d00 + d20 + k0 + k1;
955 out01 += k0 - k1 - (d03 + d23);
956 out10 += -d20 - d30 + k2 + k3;
957 out11 += k2 - k3 + d23 + d33;
958
959 int y_in = get_global_id(1);
960 int x_out = (y_in % NUM_TILES_X) * 2;
961 int y_out = (y_in / NUM_TILES_X) * 2;
962 int z_out = get_global_id(0);
963
964#if defined(HAS_BIAS)
965 // Add bias
966 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
967
968 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
969
970 out00 += (float)b;
971 out01 += (float)b;
972 out10 += (float)b;
973 out11 += (float)b;
974#endif // defined(HAS_BIAS)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000975
976 // Get output address
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000977 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000978
Gian Marco Iodiced2fab732018-03-02 11:18:12 +0000979 // Store the 2x2 output tile
980 vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
981 vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +0000982}
Giorgio Arenadd038702018-04-16 11:20:11 +0100983
/** Apply the 4x8 output-transform matrix A^T of Winograd F(4x4, 5x5) to one 8-element column.
 *
 *         | 1  1  1  1  1  8  8  0 |
 *   A^T = | 0  1 -1  2 -2  4 -4  0 |
 *         | 0  1  1  4  4  2  2  0 |
 *         | 0  1 -1  8 -8  1 -1  1 |
 *
 * col = A^T * (d0, d1, ..., d7)^T
 *
 * @param col       Destination vector; only .s0 - .s3 are written
 * @param d0 - d7   The eight input values of the column
 * @param comm_fact Scratch vector holding the shared sums/differences; its .s0 - .s2 are clobbered
 *
 * Every argument is parenthesized, so expression arguments are evaluated as a unit. d1 - d6 are each
 * evaluated twice (once for the sums, once for the differences), so arguments must be side-effect free.
 * The do { } while (0) wrapper makes the macro behave like a single statement without relying on the
 * GNU statement-expression extension.
 */
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)                         \
    do                                                                                          \
    {                                                                                           \
        /* Even output rows share the pairwise sums */                                          \
        (comm_fact).s0 = (d1) + (d2);                                                           \
        (comm_fact).s1 = (d3) + (d4);                                                           \
        (comm_fact).s2 = (d5) + (d6);                                                           \
                                                                                                \
        (col).s0 = (comm_fact).s0 + (comm_fact).s1 + 8.f * (comm_fact).s2 + (d0);               \
        (col).s2 = (comm_fact).s0 + 4.f * (comm_fact).s1 + 2.f * (comm_fact).s2;                \
                                                                                                \
        /* Odd output rows share the pairwise differences */                                    \
        (comm_fact).s0 = (d1) - (d2);                                                           \
        (comm_fact).s1 = (d3) - (d4);                                                           \
        (comm_fact).s2 = (d5) - (d6);                                                           \
                                                                                                \
        (col).s1 = (comm_fact).s0 + 2.f * (comm_fact).s1 + 4.f * (comm_fact).s2;                \
        (col).s3 = (comm_fact).s0 + 8.f * (comm_fact).s1 + (comm_fact).s2 + (d7);               \
    } while(0)
1000
1001/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data format is NCHW
1002 *
1003 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
1004 *
1005 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
1006 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
1007 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1008 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
1009 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1010 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1011 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1012 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
1013 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
1014 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1015 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1016 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1017 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1018 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1019 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
1020 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1021 */
1022__kernel void winograd_output_transform_4x4_5x5_nchw(
1023 TENSOR3D_DECLARATION(src),
1024 TENSOR3D_DECLARATION(dst)
1025#if defined(HAS_BIAS)
1026 ,
1027 VECTOR_DECLARATION(bias)
1028#endif // defined(HAS_BIAS)
1029)
1030{
1031 // Each thread stores a 4x4 tile
1032 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
1033
1034 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
1035
1036 // Load the values across the 64 channels to compose the 8x8 input tile
1037 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
1038 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
1039 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
1040 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
1041 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
1042 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
1043 float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
1044 float d07 = *((__global float *)(src_addr + 7 * src_stride_z));
1045
1046 float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
1047 float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
1048 float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
1049 float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
1050 float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
1051 float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
1052 float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
1053 float d17 = *((__global float *)(src_addr + 15 * src_stride_z));
1054
1055 float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
1056 float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
1057 float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
1058 float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
1059 float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
1060 float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
1061 float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
1062 float d27 = *((__global float *)(src_addr + 23 * src_stride_z));
1063
1064 float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
1065 float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
1066 float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
1067 float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
1068 float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
1069 float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
1070 float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
1071 float d37 = *((__global float *)(src_addr + 31 * src_stride_z));
1072
1073 float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
1074 float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
1075 float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
1076 float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
1077 float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
1078 float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
1079 float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
1080 float d47 = *((__global float *)(src_addr + 39 * src_stride_z));
1081
1082 float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
1083 float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
1084 float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
1085 float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
1086 float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
1087 float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
1088 float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
1089 float d57 = *((__global float *)(src_addr + 47 * src_stride_z));
1090
1091 float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
1092 float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
1093 float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
1094 float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
1095 float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
1096 float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
1097 float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
1098 float d67 = *((__global float *)(src_addr + 55 * src_stride_z));
1099
1100 float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
1101 float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
1102 float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
1103 float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
1104 float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
1105 float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
1106 float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
1107 float d77 = *((__global float *)(src_addr + 63 * src_stride_z));
1108
1109 // Compute the 8x4 intermediate tensor
1110 float4 comm_fact0, comm_fact1, comm_fact2;
1111 float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;
1112
1113 COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
1114 COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
1115 COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
1116 COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
1117 COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
1118 COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
1119 COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
1120 COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);
1121
1122 // Compute the 4x4 output tile
1123 comm_fact0 = tmp_col1 + tmp_col2;
1124 comm_fact1 = tmp_col3 + tmp_col4;
1125 comm_fact2 = tmp_col5 + tmp_col6;
1126
1127 float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
1128 float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;
1129
1130 comm_fact0 = tmp_col1 - tmp_col2;
1131 comm_fact1 = tmp_col3 - tmp_col4;
1132 comm_fact2 = tmp_col5 - tmp_col6;
1133
1134 float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
1135 float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;
1136
1137 int y_in = get_global_id(1);
1138 int x_out = (y_in % NUM_TILES_X) * 4;
1139 int y_out = (y_in / NUM_TILES_X) * 4;
1140 int z_out = get_global_id(0);
1141
1142#if defined(HAS_BIAS)
1143 // Add bias
1144 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
1145
1146 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
1147
1148 out_col0 += (float4)b;
1149 out_col1 += (float4)b;
1150 out_col2 += (float4)b;
1151 out_col3 += (float4)b;
1152#endif // defined(HAS_BIAS)
1153
1154 // Get output address
1155 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;
1156
1157 // Store the 4x4 output tile
1158 *(__global float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0;
1159 *(__global float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0;
1160 *(__global float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0;
1161 *(__global float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0;
1162 *(__global float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1;
1163 *(__global float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1;
1164 *(__global float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1;
1165 *(__global float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1;
1166 *(__global float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2;
1167 *(__global float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2;
1168 *(__global float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2;
1169 *(__global float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2;
1170 *(__global float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3;
1171 *(__global float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3;
1172 *(__global float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3;
1173 *(__global float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3;
1174}
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001175#endif // defined(NUM_TILES_X)