blob: ce48d28b7480f38802606263bc3f9e52f9ed8d92 [file] [log] [blame]
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001/*
2 * Copyright (c) 2018 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "helpers.h"
25
Giorgio Arenadcb5b282018-04-25 12:07:29 +010026#if defined(SRC_DIM_Z)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +000027
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
 *
 * The kernel transforms one filter slice g per invocation into U = G * g * G^T with
 * G = [[1, 0, 0], [1/2, 1/2, 1/2], [1/2, -1/2, 1/2], [0, 0, 1]].
 * The 16 coefficients of U (4 for the 3x1/1x3 variants) are written to consecutive Z planes of the
 * destination, at X = filter index and Y = channel index.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: one row of three contiguous values
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: gather the single column (one value per row) into w0 so that the
    // same 1D transform used for the horizontal case can be applied below
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0 (G[0] = [1, 0, 0], so this is also the complete 1D transform w0 * G^T)
    float4 out0 = 0.0f;
    out0.s0     = (w0.s0);
    out0.s1     = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2     = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3     = (w0.s2);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float4 out1 = 0.0f;
    out1.s0     = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3     = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2
    float4 out2 = 0.0f;
    out2.s0     = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1     = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2     = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3     = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3
    float4 out3 = 0.0f;
    out3.s0     = (w2.s0);
    out3.s1     = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2     = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Split the flattened Z index into (filter, channel) coordinates of the destination plane
    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // idx filter
    const int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels (one Z plane per transformed coefficient)
    // 16 channels for 3x3 kernels
    // 4 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena2d9de0a2018-03-15 17:58:20 +0000134
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
 *
 * The kernel transforms one filter slice g per invocation into U = G * g * G^T with
 * G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]].
 * The 36 coefficients of U (6 for the 3x1/1x3 variants) are written to consecutive Z planes of the
 * destination, at X = filter index and Y = channel index.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: one row of three contiguous values
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: gather the single column (one value per row) into w0 so that the
    // same 1D path used for the horizontal case can be applied below
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0
    // NOTE(review): row 0 equals G[0][0] * (w0 * G^T) = 1/4 * (w0 * G^T). In the 3x1/1x3 variants
    // only this row is stored, so their output carries an extra 1/4 factor relative to a pure 1D
    // transform; presumably the matching 1D input/output transforms compensate — confirm.
    // Lanes s6/s7 are unused padding of the float8 and are never stored.
    float8 out0 = 0.0f;
    out0.s0     = (w0.s0) / 16.f;
    out0.s1     = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2     = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3     = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4     = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5     = (w0.s2) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float8 out1 = 0.0f;
    out1.s0     = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5     = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    float8 out2 = 0.0f;
    out2.s0     = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1     = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2     = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5     = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    float8 out3 = 0.0f;
    out3.s0     = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5     = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    float8 out4 = 0.0f;
    out4.s0     = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5     = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    float8 out5 = 0.0f;
    out5.s0     = (w2.s0) / 4.f;
    out5.s1     = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2     = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3     = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4     = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    // Split the flattened Z index into (filter, channel) coordinates of the destination plane
    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // idx filter
    const int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels (one Z plane per transformed coefficient)
    // 36 channels for 3x3 kernels
    // 6 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena9373c8b2018-04-11 19:07:17 +0100287
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +0100288#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. Because
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined, that kernel loads a single filter row and stores
 * only the first 4 transformed coefficients.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
336
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw. Because
 * -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined, that kernel loads a single filter row and stores
 * only the first 6 transformed coefficients.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
384#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
385
386#if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. Because
 * -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined, that kernel gathers the single filter column and stores
 * only the first 4 transformed coefficients.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
434
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x4
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw. Because
 * -DWINOGRAD_FILTER_TRANSFORM_VERTICAL is defined, that kernel gathers the single filter column and stores
 * only the first 6 transformed coefficients.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
482#endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
483
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data layout is NHWC and the output tile is 4x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Base address of the 3x3 taps handled by this work-item: global id 0 selects the position along X
    // (the channel, see y0 below), global id 2 selects the filter along W.
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the values from the input tensor: tap wij is at offset i along Z and j along Y
    const float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    const float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));

    // Transform the 3x3 tile into a 6x6 tile. Every output element is fully defined below,
    // so no zero-initialisation is needed.

    // Row 0
    const float out00 = (w00) / 16.f;
    const float out01 = (-w00 - w01 - w02) / 24.f;
    const float out02 = (-w00 + w01 - w02) / 24.f;
    const float out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    const float out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    const float out05 = (w02) / 4.f;

    // Row 1
    const float out10 = (-w00 - w10 - w20) / 24.f;
    const float out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    const float out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    const float out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2
    const float out20 = (-w00 + w10 - w20) / 24.f;
    const float out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    const float out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    const float out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3
    const float out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    const float out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4
    const float out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    const float out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5
    const float out50 = (w20) / 4.f;
    const float out51 = (-w20 - w21 - w22) / 6.f;
    const float out52 = (-w20 + w21 - w22) / 6.f;
    const float out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    const float out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    const float out55 = (w22);

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 36 values across the Z planes: element (i, j) of the 6x6 tile goes to plane 6 * i + j
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out00;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out01;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out02;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out03;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out04;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out05;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
}
/** Combines the five columns of one (already row-combined) 5x5 filter row into the eight columns of the 8x8 tile.
 *
 * The result is unscaled. Lane 0 is column 0 and lane 7 is column 4. Lanes 1-6 are the weighted sums of columns 0-4
 * with weights (1, 1, 1, 1, 1), (1, -1, 1, -1, 1), (1, 2, 4, 8, 16), (1, -2, 4, -8, 16), (16, 8, 4, 2, 1)
 * and (16, -8, 4, -2, 1).
 *
 * @param[in] r  Columns 0-3 of the row
 * @param[in] r4 Column 4 of the row
 *
 * @return The eight unscaled column combinations
 */
inline float8 winograd_filter_transform_4x4_5x5_combine_columns(float4 r, float r4)
{
    return (float8)(r.s0,
                    r.s0 + r.s1 + r.s2 + r.s3 + r4,
                    r.s0 - r.s1 + r.s2 - r.s3 + r4,
                    r.s0 + 2.f * r.s1 + 4.f * r.s2 + 8.f * r.s3 + 16.f * r4,
                    r.s0 - 2.f * r.s1 + 4.f * r.s2 - 8.f * r.s3 + 16.f * r4,
                    16.f * r.s0 + 8.f * r.s1 + 4.f * r.s2 + 2.f * r.s3 + r4,
                    16.f * r.s0 - 8.f * r.s1 + 4.f * r.s2 - 2.f * r.s3 + r4,
                    r4);
}

/** Stores the eight lanes of @p values into eight consecutive Z planes.
 *
 * @param[in] dst_addr Address of the element written for lane 0
 * @param[in] stride_z Distance in bytes between two consecutive Z planes
 * @param[in] values   Values to store: lane k is written at dst_addr + k * stride_z
 */
inline void winograd_filter_transform_4x4_5x5_store8(__global uchar *dst_addr, uint stride_z, float8 values)
{
    *(__global float *)(dst_addr + 0 * stride_z) = values.s0;
    *(__global float *)(dst_addr + 1 * stride_z) = values.s1;
    *(__global float *)(dst_addr + 2 * stride_z) = values.s2;
    *(__global float *)(dst_addr + 3 * stride_z) = values.s3;
    *(__global float *)(dst_addr + 4 * stride_z) = values.s4;
    *(__global float *)(dst_addr + 5 * stride_z) = values.s5;
    *(__global float *)(dst_addr + 6 * stride_z) = values.s6;
    *(__global float *)(dst_addr + 7 * stride_z) = values.s7;
}

/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NCHW and the output tile is 4x4
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_5x5_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the 5x5 filter. Row k is read as a float4 (columns 0-3) plus a scalar (column 4).
    // Row offsets are computed with the full-width src_stride_y. Narrowing them to 8-bit vector lanes
    // would wrap modulo 256 for row strides of 64 bytes or more (e.g. padded rows) and read the wrong rows.
    const float4 w00 = vload4(0, (const __global float *)(src_addr + 0 * src_stride_y));
    const float  w01 = *((const __global float *)(src_addr + 0 * src_stride_y) + 4);
    const float4 w10 = vload4(0, (const __global float *)(src_addr + 1 * src_stride_y));
    const float  w11 = *((const __global float *)(src_addr + 1 * src_stride_y) + 4);
    const float4 w20 = vload4(0, (const __global float *)(src_addr + 2 * src_stride_y));
    const float  w21 = *((const __global float *)(src_addr + 2 * src_stride_y) + 4);
    const float4 w30 = vload4(0, (const __global float *)(src_addr + 3 * src_stride_y));
    const float  w31 = *((const __global float *)(src_addr + 3 * src_stride_y) + 4);
    const float4 w40 = vload4(0, (const __global float *)(src_addr + 4 * src_stride_y));
    const float  w41 = *((const __global float *)(src_addr + 4 * src_stride_y) + 4);

    // Transform the 5x5 tile into an 8x8 tile.
    // Step 1: unscaled combinations of the filter rows. rN holds columns 0-3 of row N of the intermediate
    // 8x5 tile and rN_4 holds its column 4.
    const float4 r0   = w00;
    const float  r0_4 = w01;
    const float4 r1   = w00 + w10 + w20 + w30 + w40;
    const float  r1_4 = w01 + w11 + w21 + w31 + w41;
    const float4 r2   = w00 - w10 + w20 - w30 + w40;
    const float  r2_4 = w01 - w11 + w21 - w31 + w41;
    const float4 r3   = w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40;
    const float  r3_4 = w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41;
    const float4 r4   = w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40;
    const float  r4_4 = w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41;
    const float4 r5   = 16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40;
    const float  r5_4 = 16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41;
    const float4 r6   = 16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40;
    const float  r6_4 = 16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41;
    const float4 r7   = w40;
    const float  r7_4 = w41;

    // Step 2: combine the columns of every row, then scale. Element (i, j) of the 8x8 tile is
    // num[j] * combination[j] / den[j]. Rows 0/7, 1/2, 3/4 and 5/6 share the same constants.
    const float8 num_0_7 = (float8)(1.f, -2.f, -2.f, 1.f, 1.f, 1.f, 1.f, 1.f);
    const float8 den_0_7 = (float8)(1.f, 9.f, 9.f, 90.f, 90.f, 180.f, 180.f, 1.f);
    const float8 num_1_2 = (float8)(-2.f, 4.f, 4.f, -1.f, -1.f, -1.f, -1.f, -2.f);
    const float8 den_1_2 = (float8)(9.f, 81.f, 81.f, 405.f, 405.f, 810.f, 810.f, 9.f);
    const float8 num_3_6 = (float8)(1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f);
    const float8 den_3_4 = (float8)(90.f, 405.f, 405.f, 8100.f, 8100.f, 16200.f, 16200.f, 90.f);
    const float8 den_5_6 = (float8)(180.f, 810.f, 810.f, 16200.f, 16200.f, 32400.f, 32400.f, 180.f);

    const float8 out0 = num_0_7 * winograd_filter_transform_4x4_5x5_combine_columns(r0, r0_4) / den_0_7;
    const float8 out1 = num_1_2 * winograd_filter_transform_4x4_5x5_combine_columns(r1, r1_4) / den_1_2;
    const float8 out2 = num_1_2 * winograd_filter_transform_4x4_5x5_combine_columns(r2, r2_4) / den_1_2;
    const float8 out3 = num_3_6 * winograd_filter_transform_4x4_5x5_combine_columns(r3, r3_4) / den_3_4;
    const float8 out4 = num_3_6 * winograd_filter_transform_4x4_5x5_combine_columns(r4, r4_4) / den_3_4;
    const float8 out5 = num_3_6 * winograd_filter_transform_4x4_5x5_combine_columns(r5, r5_4) / den_5_6;
    const float8 out6 = num_3_6 * winograd_filter_transform_4x4_5x5_combine_columns(r6, r6_4) / den_5_6;
    const float8 out7 = num_0_7 * winograd_filter_transform_4x4_5x5_combine_columns(r7, r7_4) / den_0_7;

    // Global id 2 enumerates (filter, channel) pairs, with SRC_DIM_Z channels per filter
    const int z  = get_global_id(2);
    const int x0 = z / SRC_DIM_Z; // idx filter
    const int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the 64 values across 64 Z planes: element (i, j) of the 8x8 tile goes to plane 8 * i + j
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 0 * dst_stride_z, dst_stride_z, out0);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 8 * dst_stride_z, dst_stride_z, out1);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 16 * dst_stride_z, dst_stride_z, out2);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 24 * dst_stride_z, dst_stride_z, out3);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 32 * dst_stride_z, dst_stride_z, out4);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 40 * dst_stride_z, dst_stride_z, out5);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 48 * dst_stride_z, dst_stride_z, out6);
    winograd_filter_transform_4x4_5x5_store8(dst_addr + 56 * dst_stride_z, dst_stride_z, out7);
}
Giorgio Arena80d65d82018-06-08 16:30:00 +0100902
903/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NHWC and the output tile is 4x4
904 *
905 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
906 *
907 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
908 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
909 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
910 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
911 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
912 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
913 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
914 * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes)
915 * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes)
916 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
917 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
918 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
919 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
920 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
921 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
922 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
923 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
924 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
925 */
926__kernel void winograd_filter_transform_4x4_5x5_nhwc(
927 TENSOR4D_DECLARATION(src),
928 TENSOR3D_DECLARATION(dst))
929{
930 Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);
931
932 const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;
933
934 // Load the values from the input tensor
935 float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
936 float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
937 float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
938 float w03 = *((__global float *)(src_addr + 0 * src_stride_z + 3 * src_stride_y));
939 float w04 = *((__global float *)(src_addr + 0 * src_stride_z + 4 * src_stride_y));
940 float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
941 float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
942 float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
943 float w13 = *((__global float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
944 float w14 = *((__global float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
945 float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
946 float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
947 float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
948 float w23 = *((__global float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
949 float w24 = *((__global float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
950 float w30 = *((__global float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
951 float w31 = *((__global float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
952 float w32 = *((__global float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
953 float w33 = *((__global float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
954 float w34 = *((__global float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
955 float w40 = *((__global float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
956 float w41 = *((__global float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
957 float w42 = *((__global float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
958 float w43 = *((__global float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
959 float w44 = *((__global float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));
960
961 // Transform the 3x3 tile in a 8x8 tile
962 float8 out0 = 0.0f;
963 float8 out1 = 0.0f;
964 float8 out2 = 0.0f;
965 float8 out3 = 0.0f;
966 float8 out4 = 0.0f;
967 float8 out5 = 0.0f;
968 float8 out6 = 0.0f;
969 float8 out7 = 0.0f;
970
971 // Row 0
972 out0.s0 = w00;
973 out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
974 out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
975 out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
976 out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
977 out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
978 out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
979 out0.s7 = w04;
980
981 // Row 1
982 out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
983 out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
984 out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
985 out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
986 (w04 + w14 + w24 + w34 + w44)) / 405.f;
987 out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
988 (w04 + w14 + w24 + w34 + w44)) / 405.f;
989 out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
990 (w04 + w14 + w24 + w34 + w44)) / 810.f;
991 out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
992 (w04 + w14 + w24 + w34 + w44)) / 810.f;
993 out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;
994
995 // Row 2
996 out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
997 out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
998 out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
999 out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
1000 (w04 - w14 + w24 - w34 + w44)) / 405.f;
1001 out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
1002 (w04 - w14 + w24 - w34 + w44)) / 405.f;
1003 out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
1004 (w04 - w14 + w24 - w34 + w44)) / 810.f;
1005 out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
1006 (w04 - w14 + w24 - w34 + w44)) / 810.f;
1007 out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;
1008
1009 // Row 3
1010 out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
1011 out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
1012 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
1013 out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
1014 (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
1015 out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f
1016 * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
1017 out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f
1018 * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
1019 out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
1020 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
1021 out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
1022 (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
1023 out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;
1024
1025 // Row 4
1026 out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
1027 out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
1028 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
1029 out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
1030 (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
1031 out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f
1032 * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
1033 out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f
1034 * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
1035 out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
1036 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
1037 out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
1038 (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
1039 out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;
1040
1041 // Row 5
1042 out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
1043 out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
1044 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
1045 out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
1046 (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
1047 out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f
1048 * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
1049 out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f
1050 * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
1051 out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
1052 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
1053 out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
1054 (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
1055 out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;
1056
1057 // Row 6
1058 out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
1059 out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
1060 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
1061 out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
1062 (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
1063 out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f
1064 * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
1065 out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f
1066 * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
1067 out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
1068 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
1069 out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
1070 (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
1071 out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;
1072
1073 // Row 7
1074 out7.s0 = w40;
1075 out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
1076 out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
1077 out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
1078 out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
1079 out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
1080 out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
1081 out7.s7 = w44;
1082
1083 int x0 = get_global_id(2); // idx filter
1084 int y0 = get_global_id(0); // idx channel
1085
1086 // Get output address
1087 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;
1088
1089 // Store the 64 values across the 64 channels
1090 *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
1091 *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
1092 *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
1093 *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
1094 *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
1095 *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;
1096 *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6;
1097 *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7;
1098 *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0;
1099 *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1;
1100 *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
1101 *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
1102 *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
1103 *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
1104 *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
1105 *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
1106 *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
1107 *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
1108 *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
1109 *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
1110 *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
1111 *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
1112 *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
1113 *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
1114 *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
1115 *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
1116 *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
1117 *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
1118 *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
1119 *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
1120 *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
1121 *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
1122 *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
1123 *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
1124 *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
1125 *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
1126 *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
1127 *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
1128 *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
1129 *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
1130 *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
1131 *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
1132 *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
1133 *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
1134 *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
1135 *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
1136 *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
1137 *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
1138 *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
1139 *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
1140 *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
1141 *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
1142 *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
1143 *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
1144 *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
1145 *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
1146 *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
1147 *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
1148 *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
1149 *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
1150 *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
1151 *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
1152 *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
1153 *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
1154}
Giorgio Arenadcb5b282018-04-25 12:07:29 +01001155#endif // defined(SRC_DIM_Z)
Gian Marco Iodiced2fab732018-03-02 11:18:12 +00001156
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001157#if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
1158/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001159 *
1160 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1161 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001162 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
1163 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
1164 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
1165 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001166 *
1167 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1168 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1169 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1170 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1171 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1172 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1173 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1174 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1175 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1176 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1177 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1178 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1179 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1180 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1181 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1182 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1183 */
1184__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
1185 TENSOR3D_DECLARATION(src),
1186 TENSOR3D_DECLARATION(dst))
1187{
1188 int x = get_global_id(0);
1189 int y = get_global_id(1);
1190 int z = get_global_id(2);
1191
1192 // Compute input address
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001193 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001194
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001195 src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001196
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001197#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
1198 float4 in_row0 = vload4(0, (__global float *)(src_addr));
1199#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
1200 float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
1201 *((__global float *)(src_addr + 1 * src_stride_y)),
1202 *((__global float *)(src_addr + 2 * src_stride_y)),
1203 *((__global float *)(src_addr + 3 * src_stride_y)));
1204#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001205 float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
1206 float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
1207 float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
1208 float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001209#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001210
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001211 float4 tmp0 = in_row0;
1212
1213#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1214 tmp0 -= in_row2;
1215#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001216
1217 float out00 = tmp0.s0 - tmp0.s2;
1218 float out01 = tmp0.s1 + tmp0.s2;
1219 float out02 = tmp0.s2 - tmp0.s1;
1220 float out03 = tmp0.s1 - tmp0.s3;
1221
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001222#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
1223 float4 tmp1 = in_row1 + in_row2;
1224 float4 tmp2 = in_row2 - in_row1;
1225 float4 tmp3 = in_row1 - in_row3;
1226
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001227 float out10 = tmp1.s0 - tmp1.s2;
1228 float out11 = tmp1.s1 + tmp1.s2;
1229 float out12 = tmp1.s2 - tmp1.s1;
1230 float out13 = tmp1.s1 - tmp1.s3;
1231
1232 float out20 = tmp2.s0 - tmp2.s2;
1233 float out21 = tmp2.s1 + tmp2.s2;
1234 float out22 = tmp2.s2 - tmp2.s1;
1235 float out23 = tmp2.s1 - tmp2.s3;
1236
1237 float out30 = tmp3.s0 - tmp3.s2;
1238 float out31 = tmp3.s1 + tmp3.s2;
1239 float out32 = tmp3.s2 - tmp3.s1;
1240 float out33 = tmp3.s1 - tmp3.s3;
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001241#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001242
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001243 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001244
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001245 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00; // in_row0.s0; out00;
1246 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01; // in_row0.s1; out01;
1247 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02; // in_row0.s2; out02;
1248 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03; // in_row0.s3; out03;
1249
1250#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001251 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out10;
1252 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out11;
1253 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out12;
1254 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out13;
1255 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out20;
1256 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out21;
1257 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
1258 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
1259 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
1260 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
1261 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
1262 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01001263#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
Giorgio Arena1f9ca1d2018-03-01 11:13:45 +00001264}
1265
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
 *
 * Each work-item transforms one input tile for two consecutive channels (z = 2 * get_global_id(2)).
 * The 4x4 (or 4x1 / 1x4) input patch d is transformed as B^T * d * B, with
 *
 *   B^T = [ 1  0 -1  0 ]
 *         [ 0  1  1  0 ]
 *         [ 0 -1  1  0 ]
 *         [ 0  1  0 -1 ]
 *
 * Element k of the transformed tile is written to plane k of the destination (k * dst_stride_z):
 * 16 planes for the 2D transform, 4 planes for the 3x1/1x3 variants. Within a plane the result goes
 * to row (x + y * NUM_TILES_X), columns [z, z + 1] (one float2 per store).
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2) * 2; // Two channels (z and z + 1) per work-item

    // Compute input address: top-left corner of the tile for channel z
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;

    // Shift back by the padding so the patch starts at the first (possibly padded) input element
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);

    // Load the patch of channel z
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row0 = vload4(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1x3: gather one column into a vector so the same 1D transform below can be reused
    float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                              *((__global float *)(src_addr + 1 * src_stride_y)),
                              *((__global float *)(src_addr + 2 * src_stride_y)),
                              *((__global float *)(src_addr + 3 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Move to channel z + 1 and load its patch
    src_addr += src_stride_z;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row4 = vload4(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row4 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                              *((__global float *)(src_addr + 1 * src_stride_y)),
                              *((__global float *)(src_addr + 2 * src_stride_y)),
                              *((__global float *)(src_addr + 3 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Row transform (B^T * d), row 0: d0 - d2. In the 1D variants the single loaded vector is used as-is.
    float4 tmp0 = in_row0;
    float4 tmp4 = in_row4;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    tmp0 -= in_row2;
    tmp4 -= in_row6;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Column transform (* B); .x holds channel z, .y holds channel z + 1
    float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
    float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
    float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
    float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining rows of B^T * d: (d1 + d2), (d2 - d1), (d1 - d3)
    float4 tmp1 = in_row1 + in_row2;
    float4 tmp2 = in_row2 - in_row1;
    float4 tmp3 = in_row1 - in_row3;

    float4 tmp5 = in_row5 + in_row6;
    float4 tmp6 = in_row6 - in_row5;
    float4 tmp7 = in_row5 - in_row7;

    float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
    float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
    float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
    float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);

    float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
    float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
    float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
    float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);

    float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
    float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
    float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
    float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Destination: row = tile index, columns [z, z + 1], one plane per transformed element
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;

    vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
    vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
    vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
    vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
    vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
    vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
    vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
    vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
    vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
    vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
    vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
    vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
    vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
    vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
    vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
Giorgio Arenafe5ef382018-04-17 10:14:10 +01001394
/** This OpenCL kernel computes the input transform when the output tile is 4x4 (4x1/1x4), the filter size 3x3 (3x1/1x3) and the data layout is NCHW
 *
 * Each work-item transforms one 6x6 (or 6x1 / 1x6) input patch d of channel z as B^T * d * B, with
 *
 *   B^T = [ 4  0 -5  0  1  0 ]
 *         [ 0 -4 -4  1  1  0 ]
 *         [ 0  4 -4 -1  1  0 ]
 *         [ 0 -2 -1  2  1  0 ]
 *         [ 0  2 -1 -2  1  0 ]
 *         [ 0  4  0 -5  0  1 ]
 *
 * Element k of the transformed tile is written to plane k of the destination (k * dst_stride_z):
 * 36 planes for the 2D transform, 6 planes for the 1D variants. Within a plane the result goes to
 * row (x + y * NUM_TILES_X), column z.
 *
 * NOTE(review): in the 1D variants only "row 0" is computed and it keeps the 2D row-0 coefficients
 * (16, -20, 4, ...), i.e. the result is 4x the plain 1D B^T * d. Presumably the matching filter
 * transform compensates for this factor - confirm before changing either side.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x4_3x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2);

    // Compute input address: top-left corner of the tile for channel z
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;

    // Shift back by the padding so the patch starts at the first (possibly padded) input element
    src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);

#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row0: for 1x3 the 6 elements of one column are gathered so the row-0 formulas below can be reused
    float4 d00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                          *((__global float *)(src_addr + 1 * src_stride_y)),
                          *((__global float *)(src_addr + 2 * src_stride_y)),
                          *((__global float *)(src_addr + 3 * src_stride_y)));
    float2 d01 = (float2)(*((__global float *)(src_addr + 4 * src_stride_y)),
                          *((__global float *)(src_addr + 5 * src_stride_y)));
#else // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row0: elements [0, 3] in d00, elements [4, 5] in d01
    float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    float out0 = 0.0f;
    float out1 = 0.0f;
    float out2 = 0.0f;
    float out3 = 0.0f;
    float out4 = 0.0f;
    float out5 = 0.0f;

    // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
    // Contribution of row0 (B^T row-0 weight for row0 is 4)
    out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0;
    out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0;
    out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0;
    out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0;
    out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0;
    out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Row4
    float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y));
    float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y));

    // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4:
    // row4 has weight 1 in the first five rows of B^T, so its column-transformed values are shared
    float k0 = d41.s0;
    float k1 = d41.s0;
    float k2 = d41.s0;
    float k3 = d41.s0;
    float k4 = d41.s0;
    float k5 = 0.0f;

    k0 += 4.0f * d40.s0 - 5.0f * d40.s2;
    k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3;
    k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3;
    k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2;
    k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2;
    k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1;

    out0 += k0;
    out1 += k1;
    out2 += k2;
    out3 += k3;
    out4 += k4;
    out5 += k5;

    // Row2
    float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y));

    // Contribution of row2 to channels [0, 5] (B^T row-0 weight for row2 is -5)
    out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0;
    out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0;
    out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0;
    out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0;
    out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0;
    out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1;
#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Compute destination address: row = tile index, column = channel z
    __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y);

    // Plane stride expressed in floats, since dst_addr is a float pointer
    uint dst_plane_stride = dst_stride_z / sizeof(float);

    *(dst_addr) = out0;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out1;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out2;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out3;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out4;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out5;
    dst_addr += dst_plane_stride;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Channels [6, 29] all receive the row4 term with weight 1
    float out6  = k0;
    float out7  = k1;
    float out8  = k2;
    float out9  = k3;
    float out10 = k4;
    float out11 = k5;
    float out12 = k0;
    float out13 = k1;
    float out14 = k2;
    float out15 = k3;
    float out16 = k4;
    float out17 = k5;
    float out18 = k0;
    float out19 = k1;
    float out20 = k2;
    float out21 = k3;
    float out22 = k4;
    float out23 = k5;
    float out24 = k0;
    float out25 = k1;
    float out26 = k2;
    float out27 = k3;
    float out28 = k4;
    float out29 = k5;

    // Row1
    float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y));

    // Row3
    float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
    float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y));

    // Compute common parts for the channels between [6, 29]
    // Channels [6, 11]:  [out10, out11, out12, out13, out14, out15]
    // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
    float part0  = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0;
    float part1  = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0;
    float part2  = 16.0f * d20.s2 - 4.0f * d21.s0;
    float part3  = 16.0f * d20.s1 - 4.0f * d20.s3;
    float part4  = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0;
    float part5  = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3;
    float part6  = 4.0f * d20.s2 - 4.0f * d21.s0;
    float part7  = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3;
    float part8  = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0;
    float part9  = 8.0f * d20.s1 - 8.0f * d20.s3;
    float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1;
    float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1;

    // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
    // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
    float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0;
    float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0
    float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0
    float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3;
    float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0;
    float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3
    float part18 = part6 * 0.25f; // d20.s2 - d21.s0
    float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3;
    float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0;
    float part21 = part9 * 0.25f;                                                 // 2.0f * (d20.s1 - d20.s3)
    float part22 = part10 * 0.25f;                                                // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1
    float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1;

    out6 += part0 - part1;
    out12 += part0 + part1;
    out7 += part2 + part3 + part4 + part5;
    out8 += part2 - part3 + part4 - part5;
    out13 += part2 + part3 - part4 - part5;
    out14 += part2 - part3 - part4 + part5;
    out9 += part6 + part7 + part8 + part9;
    out10 += part6 - part7 + part8 - part9;
    out15 += part6 - part7 - part8 + part9;
    out16 += part6 + part7 - part8 - part9;
    out11 += part10 + part11;
    out17 += part10 - part11;

    out18 += part13 - part12;
    out24 += part13 + part12;
    out19 += part14 + part15 + part16 + part17;
    out20 += part14 - part15 + part16 - part17;
    out25 += part14 - part15 - part16 + part17;
    out26 += part14 + part15 - part16 - part17;
    out21 += part18 + part19 + part20 + part21;
    out22 += part18 - part19 + part20 - part21;
    out27 += part18 - part19 - part20 + part21;
    out28 += part18 + part19 - part20 - part21;
    out23 += part22 + part23;
    out29 += part22 - part23;

    *(dst_addr) = out6;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out7;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out8;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out9;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out10;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out11;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out12;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out13;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out14;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out15;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out16;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out17;
    dst_addr += dst_plane_stride;

    *(dst_addr) = out18;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out19;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out20;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out21;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out22;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out23;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out24;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out25;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out26;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out27;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out28;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out29;
    dst_addr += dst_plane_stride;

    // Row5
    float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y));
    float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y));

    // Channels [30, 35]: B^T row 5 weights rows 1, 3, 5 by (4, -5, 1); the row4 term does not contribute
    out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
    out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
    out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
    out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
    out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0;
    out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1;

    *(dst_addr) = out0;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out1;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out2;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out3;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out4;
    dst_addr += dst_plane_stride;
    *(dst_addr) = out5;
    dst_addr += dst_plane_stride;
#endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
1687
Giorgio Arenac42f28d2018-04-26 11:33:05 +01001688#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
1689/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC
1690 *
1691 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
1692 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
Giorgio Arenac42f28d2018-04-26 11:33:05 +01001695 *
1696 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
1697 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
1698 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
1699 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
1700 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
1701 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
1702 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
1703 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
1704 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
1705 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
1706 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
1707 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
1708 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
1709 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
1710 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
1711 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
1712 */
1713__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
1714 TENSOR3D_DECLARATION(src),
1715 TENSOR3D_DECLARATION(dst))
1716{
1717 int x = get_global_id(0);
1718 int y = get_global_id(1);
1719 int z = get_global_id(2);
1720
1721 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * src_stride_x;
1722
1723 // Clamp coordinates. This clamp is valid for all rows
1724 int4 y_coord0 = (int4)(y * 4) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
1725 int2 y_coord1 = (int2)(y * 4) + (int2)(4, 5) - (int2)PAD_LEFT;
1726 y_coord0 = clamp(y_coord0, -1, SRC_DIM_1);
1727 y_coord1 = clamp(y_coord1, -1, SRC_DIM_1);
1728
1729 // Row4
1730 int z_coord = (z * 4) - PAD_TOP + 4;
1731
1732 // If z < 0, set y to -1
1733 int4 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
1734 int2 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
1735 // If z >= SRC_DIM_2, set y to SRC_DIM_2
1736 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
1737 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
1738
1739 // Clamp z coordinate
1740 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1741
1742 float d40 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1743 float d41 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1744 float d42 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1745 float d43 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1746 float d44 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1747 float d45 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
1748
1749 float k0 = d44;
1750 float k1 = d44;
1751 float k2 = d44;
1752 float k3 = d44;
1753 float k4 = d44;
1754 float k5 = (float)0.0f;
1755
1756 k0 += 4.0f * d40 - 5.0f * d42;
1757 k1 += -4.0f * d41 - 4.0f * d42 + d43;
1758 k2 += 4.0f * d41 - 4.0f * d42 - d43;
1759 k3 += -2.0f * d41 + 2.0f * d43 - d42;
1760 k4 += 2.0f * d41 - 2.0f * d43 - d42;
1761 k5 += 4.0f * d41 - 5.0f * d43 + d45;
1762
1763 // Row0
1764 z_coord = (z * 4) - PAD_TOP + 0;
1765
1766#if PAD_TOP != 0
1767 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
1768 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
1769 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
1770 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
1771 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1772#else // PAD_TOP != 0
1773 valid_y0 = y_coord0;
1774 valid_y1 = y_coord1;
1775#endif // if PAD_TOP == 0, we cannot read out of bound
1776
1777 float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1778 float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1779 float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1780 float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1781 float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1782 float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
1783
1784 // Row2
1785 z_coord = (z * 4) - PAD_TOP + 2;
1786 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
1787 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
1788 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
1789 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
1790 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1791
1792 float d20 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1793 float d21 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1794 float d22 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1795 float d23 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1796 float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1797 float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
1798
1799 // Compute destination address
1800 __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_x + (y + z * (int)NUM_TILES_X) * dst_stride_y);
1801
1802 uint dst_plane_stride = dst_stride_z / sizeof(float);
1803
1804 float out0 = k0;
1805 float out1 = k1;
1806 float out2 = k2;
1807 float out3 = k3;
1808 float out4 = k4;
1809 float out5 = k5;
1810 float out6 = k0;
1811 float out7 = k1;
1812 float out8 = k2;
1813 float out9 = k3;
1814 float out10 = k4;
1815 float out11 = k5;
1816 float out12 = k0;
1817 float out13 = k1;
1818 float out14 = k2;
1819 float out15 = k3;
1820 float out16 = k4;
1821 float out17 = k5;
1822 float out18 = k0;
1823 float out19 = k1;
1824 float out20 = k2;
1825 float out21 = k3;
1826 float out22 = k4;
1827 float out23 = k5;
1828 float out24 = k0;
1829 float out25 = k1;
1830 float out26 = k2;
1831 float out27 = k3;
1832 float out28 = k4;
1833 float out29 = k5;
1834
1835 // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
1836 out0 += 16.0f * d00 - 20.0f * d02 - 20.0f * d20 + 25.0f * d22 + 4.0f * d04 - 5.0f * d24;
1837 out1 += -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 20.0f * d21 + 20.0f * d22 - 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
1838 out2 += 16.0f * d01 - 16.0f * d02 - 4.0f * d03 - 20.0f * d21 + 20.0f * d22 + 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
1839 out3 += -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 10.0f * d21 + 5.0f * d22 - 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
1840 out4 += 8.0f * d01 - 4.0f * d02 - 8.0f * d03 - 10.0f * d21 + 5.0f * d22 + 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
1841 out5 += 16.0f * d01 - 20.0f * d03 - 20.0f * d21 + 4.0f * d05 + 25.0f * d23 - 5.0f * d25;
1842
1843 *((__global float *)dst_addr) = out0;
1844 dst_addr += dst_plane_stride;
1845 *((__global float *)dst_addr) = out1;
1846 dst_addr += dst_plane_stride;
1847 *((__global float *)dst_addr) = out2;
1848 dst_addr += dst_plane_stride;
1849 *((__global float *)dst_addr) = out3;
1850 dst_addr += dst_plane_stride;
1851 *((__global float *)dst_addr) = out4;
1852 dst_addr += dst_plane_stride;
1853 *((__global float *)dst_addr) = out5;
1854 dst_addr += dst_plane_stride;
1855
1856 // Row1
1857 z_coord = (z * 4) - PAD_TOP + 1;
1858 // Row1 can never be out of bounds
1859 valid_y0 = y_coord0;
1860 valid_y1 = y_coord1;
1861
1862 float d10 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1863 float d11 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1864 float d12 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1865 float d13 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1866 float d14 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1867 float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
1868
1869 // Row3
1870 z_coord = (z * 4) - PAD_TOP + 3;
1871 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
1872 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
1873 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
1874 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
1875 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1876 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1877
1878 float d30 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
1879 float d31 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
1880 float d32 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
1881 float d33 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
1882 float d34 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
1883 float d35 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
1884
1885 // Compute common parts for the channels between [6, 29]
1886 // Channels [6, 11]: [out10, out11, out12, out13, out14, out15]
1887 // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
1888 float part0 = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
1889 float part1 = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
1890 float part2 = 16.0f * d22 - 4.0f * d24;
1891 float part3 = 16.0f * d21 - 4.0f * d23;
1892 float part4 = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
1893 float part5 = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
1894 float part6 = 4.0f * d22 - 4.0f * d24;
1895 float part7 = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
1896 float part8 = 4.0f * d12 - 4.0f * d14 - d32 + d34;
1897 float part9 = 8.0f * d21 - 8.0f * d23;
1898 float part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
1899 float part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;
1900
1901 // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
1902 // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
1903 float part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
1904 float part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
1905 float part14 = part2 * 0.25f; // 4.0f * d22 - d24
1906 float part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
1907 float part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
1908 float part17 = part3 * 0.25f; // 4.0f * d21 - d23
1909 float part18 = part6 * 0.25f; // d22 - d24
1910 float part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
1911 float part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
1912 float part21 = part9 * 0.25f; // 2.0f * (d21 - d23)
1913 float part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
1914 float part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;
1915
1916 out6 += part0 - part1;
1917 out12 += part0 + part1;
1918 out7 += part2 + part3 + part4 + part5;
1919 out8 += part2 - part3 + part4 - part5;
1920 out13 += part2 + part3 - part4 - part5;
1921 out14 += part2 - part3 - part4 + part5;
1922 out9 += part6 + part7 + part8 + part9;
1923 out10 += part6 - part7 + part8 - part9;
1924 out15 += part6 - part7 - part8 + part9;
1925 out16 += part6 + part7 - part8 - part9;
1926 out11 += part10 + part11;
1927 out17 += part10 - part11;
1928
1929 out18 += part13 - part12;
1930 out24 += part13 + part12;
1931 out19 += part14 + part15 + part16 + part17;
1932 out20 += part14 - part15 + part16 - part17;
1933 out25 += part14 - part15 - part16 + part17;
1934 out26 += part14 + part15 - part16 - part17;
1935 out21 += part18 + part19 + part20 + part21;
1936 out22 += part18 - part19 + part20 - part21;
1937 out27 += part18 - part19 - part20 + part21;
1938 out28 += part18 + part19 - part20 - part21;
1939 out23 += part22 + part23;
1940 out29 += part22 - part23;
1941
1942 *((__global float *)dst_addr) = out6;
1943 dst_addr += dst_plane_stride;
1944 *((__global float *)dst_addr) = out7;
1945 dst_addr += dst_plane_stride;
1946 *((__global float *)dst_addr) = out8;
1947 dst_addr += dst_plane_stride;
1948 *((__global float *)dst_addr) = out9;
1949 dst_addr += dst_plane_stride;
1950 *((__global float *)dst_addr) = out10;
1951 dst_addr += dst_plane_stride;
1952 *((__global float *)dst_addr) = out11;
1953 dst_addr += dst_plane_stride;
1954 *((__global float *)dst_addr) = out12;
1955 dst_addr += dst_plane_stride;
1956 *((__global float *)dst_addr) = out13;
1957 dst_addr += dst_plane_stride;
1958 *((__global float *)dst_addr) = out14;
1959 dst_addr += dst_plane_stride;
1960 *((__global float *)dst_addr) = out15;
1961 dst_addr += dst_plane_stride;
1962 *((__global float *)dst_addr) = out16;
1963 dst_addr += dst_plane_stride;
1964 *((__global float *)dst_addr) = out17;
1965 dst_addr += dst_plane_stride;
1966
1967 *((__global float *)dst_addr) = out18;
1968 dst_addr += dst_plane_stride;
1969 *((__global float *)dst_addr) = out19;
1970 dst_addr += dst_plane_stride;
1971 *((__global float *)dst_addr) = out20;
1972 dst_addr += dst_plane_stride;
1973 *((__global float *)dst_addr) = out21;
1974 dst_addr += dst_plane_stride;
1975 *((__global float *)dst_addr) = out22;
1976 dst_addr += dst_plane_stride;
1977 *((__global float *)dst_addr) = out23;
1978 dst_addr += dst_plane_stride;
1979 *((__global float *)dst_addr) = out24;
1980 dst_addr += dst_plane_stride;
1981 *((__global float *)dst_addr) = out25;
1982 dst_addr += dst_plane_stride;
1983 *((__global float *)dst_addr) = out26;
1984 dst_addr += dst_plane_stride;
1985 *((__global float *)dst_addr) = out27;
1986 dst_addr += dst_plane_stride;
1987 *((__global float *)dst_addr) = out28;
1988 dst_addr += dst_plane_stride;
1989 *((__global float *)dst_addr) = out29;
1990 dst_addr += dst_plane_stride;
1991
1992 // Row5
1993 z_coord = (z * 4) - PAD_TOP + 5;
1994 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
1995 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
1996 valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
1997 valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
1998 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
1999 z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);
2000
2001 float d50 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
2002 float d51 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
2003 float d52 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
2004 float d53 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
2005 float d54 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
2006 float d55 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);
2007
2008 // Channels [30, 35]
2009 out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
2010 out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
2011 out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
2012 out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
2013 out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
2014 out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;
2015
2016 *((__global float *)dst_addr) = out0;
2017 dst_addr += dst_plane_stride;
2018 *((__global float *)dst_addr) = out1;
2019 dst_addr += dst_plane_stride;
2020 *((__global float *)dst_addr) = out2;
2021 dst_addr += dst_plane_stride;
2022 *((__global float *)dst_addr) = out3;
2023 dst_addr += dst_plane_stride;
2024 *((__global float *)dst_addr) = out4;
2025 dst_addr += dst_plane_stride;
2026 *((__global float *)dst_addr) = out5;
2027 dst_addr += dst_plane_stride;
2028}
2029
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002030#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arenac42f28d2018-04-26 11:33:05 +01002031
/** Apply the 8-point input-transform operator used by the Winograd 4x4/5x5 kernels to one row of 8 values.
 *
 * @param[out]    out       Receives the 8 transformed values (.s0 - .s7)
 * @param[in]     tmp       Row of 8 values to transform (.s0 - .s7)
 * @param[in,out] comm_fact Scratch vector holding shared sub-expressions; components .s0 - .s6 are overwritten
 *
 * @note Expands to a single statement (do { ... } while(0)), so it composes safely with unbraced if/else
 *       and does not depend on the GNU statement-expression extension.
 * @note Arguments are evaluated several times: pass plain variables, never expressions with side effects.
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                          \
    do                                                                                   \
    {                                                                                    \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                         \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                         \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                                \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;              \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;                 \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                     \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;              \
                                                                                         \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;            \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                      \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                      \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                      \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                      \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                      \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                      \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;            \
    } while(0)
2050 })
2051
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002052/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4 when the data layout is NCHW
Giorgio Arenafe5ef382018-04-17 10:14:10 +01002053 *
2054 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2055 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2056 *
2057 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2058 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2059 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2060 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2061 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2062 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2063 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2064 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2065 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2066 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2067 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2068 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2069 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2070 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2071 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2072 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2073 */
2074__kernel void winograd_input_transform_4x4_5x5_stepz1_nchw(
2075 TENSOR3D_DECLARATION(src),
2076 TENSOR3D_DECLARATION(dst))
2077{
2078 int x = get_global_id(0);
2079 int y = get_global_id(1);
2080 int z = get_global_id(2);
2081
2082 // Compute input address
2083 __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z;
2084
2085 src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y);
2086
2087 // Load 8x8 input tile
2088 const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y));
2089 const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y));
2090 const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y));
2091 const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y));
2092 const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y));
2093 const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y));
2094 const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y));
2095 const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y));
2096
2097 // Calculate common factors for intermediate tensor
2098 float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
2099 float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
2100 float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;
2101
2102 // Calculate intermediate tensor and reuse common factor vectors
2103 const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
2104 const float8 tmp1 = comm_fact0 + comm_fact1;
2105 const float8 tmp2 = comm_fact0 - comm_fact1;
2106
2107 comm_fact0 = 2.5f * in_row3;
2108 comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;
2109
2110 const float8 tmp3 = comm_fact1 + comm_fact2;
2111 const float8 tmp4 = comm_fact2 - comm_fact1;
2112
2113 comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
2114 comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;
2115
2116 const float8 tmp5 = comm_fact1 + comm_fact2;
2117 const float8 tmp6 = comm_fact2 - comm_fact1;
2118 const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;
2119
2120 // Calculate output rows (reuse comm_fact0 vector)
2121 float8 out0, out1, out2, out3, out4, out5, out6, out7;
2122
2123 OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
2124 OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
2125 OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
2126 OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
2127 OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
2128 OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
2129 OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
2130 OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);
2131
2132 // Store values across the 64 channels
2133 __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y;
2134
2135 *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0;
2136 *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1;
2137 *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2;
2138 *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3;
2139 *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4;
2140 *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5;
2141 *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6;
2142 *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7;
2143 *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0;
2144 *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1;
2145 *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2;
2146 *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3;
2147 *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4;
2148 *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5;
2149 *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6;
2150 *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7;
2151 *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0;
2152 *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1;
2153 *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2;
2154 *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3;
2155 *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4;
2156 *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5;
2157 *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6;
2158 *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7;
2159 *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0;
2160 *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1;
2161 *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2;
2162 *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3;
2163 *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4;
2164 *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5;
2165 *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6;
2166 *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7;
2167 *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0;
2168 *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1;
2169 *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2;
2170 *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3;
2171 *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4;
2172 *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5;
2173 *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6;
2174 *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7;
2175 *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0;
2176 *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1;
2177 *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2;
2178 *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3;
2179 *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4;
2180 *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5;
2181 *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6;
2182 *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7;
2183 *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0;
2184 *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1;
2185 *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2;
2186 *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3;
2187 *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4;
2188 *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5;
2189 *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6;
2190 *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7;
2191 *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0;
2192 *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1;
2193 *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2;
2194 *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3;
2195 *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4;
2196 *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5;
2197 *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6;
2198 *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7;
2199}
Giorgio Arenabe39f122018-06-08 17:50:38 +01002200
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002201#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
2202/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1
Giorgio Arenabe39f122018-06-08 17:50:38 +01002203 *
2204 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2205 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002206 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
2207 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2208 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
2209 *
2210 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2211 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2212 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2213 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2214 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2215 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2216 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2217 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2218 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2219 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2220 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2221 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2222 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2223 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2224 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2225 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2226 */
2227__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
2228 TENSOR3D_DECLARATION(src),
2229 TENSOR3D_DECLARATION(dst))
2230{
2231 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
2232 src_stride_x,
2233 src_step_x,
2234 src_stride_y,
2235 src_step_y,
2236 src_stride_z,
2237 src_step_z,
2238 src_offset_first_element_in_bytes,
2239 dst_ptr,
2240 dst_stride_x,
2241 dst_step_x,
2242 dst_stride_y,
2243 dst_step_y,
2244 dst_stride_z,
2245 dst_step_z,
2246 dst_offset_first_element_in_bytes);
2247}
2248
2249/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is multiple of 2
2250 *
2251 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2252 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2253 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
2254 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2255 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
2256 *
2257 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2258 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2259 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2260 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2261 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2262 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2263 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2264 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2265 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2266 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2267 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2268 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2269 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2270 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2271 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2272 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2273 */
2274__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
2275 TENSOR3D_DECLARATION(src),
2276 TENSOR3D_DECLARATION(dst))
2277{
2278 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
2279 src_stride_x,
2280 src_step_x,
2281 src_stride_y,
2282 src_step_y,
2283 src_stride_z,
2284 src_step_z,
2285 src_offset_first_element_in_bytes,
2286 dst_ptr,
2287 dst_stride_x,
2288 dst_step_x,
2289 dst_stride_y,
2290 dst_step_y,
2291 dst_stride_z,
2292 dst_step_z,
2293 dst_offset_first_element_in_bytes);
2294}
2295
2296/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1
2297 *
2298 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2299 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2300 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
2301 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
2302 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
2303 *
2304 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2305 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2306 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2307 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2308 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2309 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2310 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2311 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2312 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2313 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2314 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2315 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2316 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2317 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2318 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2319 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2320 */
2321__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
2322 TENSOR3D_DECLARATION(src),
2323 TENSOR3D_DECLARATION(dst))
2324{
2325 winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
2326 src_stride_x,
2327 src_step_x,
2328 src_stride_y,
2329 src_step_y,
2330 src_stride_z,
2331 src_step_z,
2332 src_offset_first_element_in_bytes,
2333 dst_ptr,
2334 dst_stride_x,
2335 dst_step_x,
2336 dst_stride_y,
2337 dst_step_y,
2338 dst_stride_z,
2339 dst_step_z,
2340 dst_offset_first_element_in_bytes);
2341}
2342#endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
2343
2344#if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
2345/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2
2346 *
2347 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5).
2348 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0).
2349 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2350 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
2351 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
2352 *
2353 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2354 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2355 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2356 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2357 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2358 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2359 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2360 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2361 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2362 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2363 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2364 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2365 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2366 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2367 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2368 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2369 */
2370__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
2371 TENSOR3D_DECLARATION(src),
2372 TENSOR3D_DECLARATION(dst))
2373{
2374 winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr,
2375 src_stride_x,
2376 src_step_x,
2377 src_stride_y,
2378 src_step_y,
2379 src_stride_z,
2380 src_step_z,
2381 src_offset_first_element_in_bytes,
2382 dst_ptr,
2383 dst_stride_x,
2384 dst_step_x,
2385 dst_stride_y,
2386 dst_step_y,
2387 dst_stride_z,
2388 dst_step_z,
2389 dst_offset_first_element_in_bytes);
2390}
2391
2392/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is multiple of 2
2393 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
2396 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
2397 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
2398 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
2399 *
2400 * @param[in] src_ptr Pointer to the source image. Supported data types: F32
2401 * @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
2402 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
2403 * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes)
2404 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
2405 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
2406 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
2407 * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes)
2408 * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr
2409 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
2410 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
2411 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
2412 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
2413 * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes)
2414 * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes)
2415 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
2416 */
2417__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
2418 TENSOR3D_DECLARATION(src),
2419 TENSOR3D_DECLARATION(dst))
2420{
2421 winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr,
2422 src_stride_x,
2423 src_step_x,
2424 src_stride_y,
2425 src_step_y,
2426 src_stride_z,
2427 src_step_z,
2428 src_offset_first_element_in_bytes,
2429 dst_ptr,
2430 dst_stride_x,
2431 dst_step_x,
2432 dst_stride_y,
2433 dst_step_y,
2434 dst_stride_z,
2435 dst_step_z,
2436 dst_offset_first_element_in_bytes);
2437}
2438
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4
 *
 * All arguments are forwarded unchanged to winograd_input_transform_4x4_3x3_stepz1_nchw; the 1-D (vertical)
 * behaviour is presumably selected inside that kernel by -DWINOGRAD_INPUT_TRANSFORM_VERTICAL.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr,
                                                 src_stride_x,
                                                 src_step_x,
                                                 src_stride_y,
                                                 src_step_y,
                                                 src_stride_z,
                                                 src_step_z,
                                                 src_offset_first_element_in_bytes,
                                                 dst_ptr,
                                                 dst_stride_x,
                                                 dst_step_x,
                                                 dst_stride_y,
                                                 dst_step_y,
                                                 dst_stride_z,
                                                 dst_step_z,
                                                 dst_offset_first_element_in_bytes);
}
2485#endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
2486
2487#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** Load one 8-element row of the 8x8 input tile used by winograd_input_transform_4x4_5x5_stepz1_nhwc
 *
 * If @p z_coord is below 0, every element is read from y == -1. If it is at or beyond SRC_DIM_2, every element is read
 * from y == SRC_DIM_1. In both cases z is clamped to [0, SRC_DIM_2 - 1], so the z offset always stays inside the tensor.
 * NOTE(review): this presumably relies on the source tensor being padded so that y == -1 and y == SRC_DIM_1 hold
 * zeros — confirm against the host-side padding configuration.
 *
 * @param[in] src_addr     Address of the element at y == 0, z == 0 for the current channel
 * @param[in] y_coord      The 8 y (dimension 1) coordinates of the row, already clamped to [-1, SRC_DIM_1]
 * @param[in] z_coord      The z (dimension 2) coordinate of the row; may be out of range
 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
 *
 * @return The 8 values of the row
 */
inline float8 winograd_input_4x4_5x5_nhwc_load_row(__global uchar *src_addr, int8 y_coord, int z_coord, uint src_stride_y, uint src_stride_z)
{
    int8 valid_y = select(y_coord, -1, (int8)z_coord < 0);                // If z < 0, set y to -1
    valid_y      = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_1
    z_coord      = clamp(z_coord, 0, SRC_DIM_2 - 1);                      // Clamp z coordinate

    float8 row;
    row.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z);
    row.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z);
    row.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z);
    row.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z);
    row.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z);
    row.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z);
    row.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z);
    row.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z);

    return row;
}

/** Store the 8 values of one transformed row into 8 consecutive Z planes of the destination tensor
 *
 * @param[in] dst_addr      Destination address of plane 0 for the current channel/tile
 * @param[in] row           The 8 values to store
 * @param[in] first_channel Index of the Z plane that receives @p row.s0 (row.s1 goes to first_channel + 1, ...)
 * @param[in] dst_stride_z  Stride of the destination tensor in Z dimension (in bytes)
 */
inline void winograd_input_4x4_5x5_nhwc_store_row(__global uchar *dst_addr, float8 row, uint first_channel, uint dst_stride_z)
{
    *((__global float *)(dst_addr + (first_channel + 0) * dst_stride_z)) = row.s0;
    *((__global float *)(dst_addr + (first_channel + 1) * dst_stride_z)) = row.s1;
    *((__global float *)(dst_addr + (first_channel + 2) * dst_stride_z)) = row.s2;
    *((__global float *)(dst_addr + (first_channel + 3) * dst_stride_z)) = row.s3;
    *((__global float *)(dst_addr + (first_channel + 4) * dst_stride_z)) = row.s4;
    *((__global float *)(dst_addr + (first_channel + 5) * dst_stride_z)) = row.s5;
    *((__global float *)(dst_addr + (first_channel + 6) * dst_stride_z)) = row.s6;
    *((__global float *)(dst_addr + (first_channel + 7) * dst_stride_z)) = row.s7;
}

/** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4 when the data layout is NHWC
 *
 * Work-item mapping: get_global_id(0) selects the element along dimension 0 (the channel in NHWC).
 * get_global_id(1) and get_global_id(2) select the tile along dimensions 1 and 2; each tile reads an 8x8 input
 * window starting at (id * 4 - PAD_LEFT) and (id * 4 - PAD_TOP) respectively.
 * The 64 transformed values are written to 64 consecutive Z planes of the destination, at Y index
 * (tile_y + tile_z * NUM_TILES_X).
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note The size of the source tensor along dimensions 1 and 2 must be passed at compile time using -DSRC_DIM_1 and -DSRC_DIM_2
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2);

    // Compute input address
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float);

    // Clamp coordinates. This clamp is valid for all rows
    int8 y_coord = (int8)(y * 4) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT;
    y_coord      = clamp(y_coord, -1, SRC_DIM_1);

    // Load 8x8 input tile: row i reads z coordinate (z * 4 - PAD_TOP + i)
    const int z_base = (z * 4) - PAD_TOP;

    float8 in_row0 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 0, src_stride_y, src_stride_z);
    float8 in_row1 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 1, src_stride_y, src_stride_z);
    float8 in_row2 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 2, src_stride_y, src_stride_z);
    float8 in_row3 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 3, src_stride_y, src_stride_z);
    float8 in_row4 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 4, src_stride_y, src_stride_z);
    float8 in_row5 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 5, src_stride_y, src_stride_z);
    float8 in_row6 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 6, src_stride_y, src_stride_z);
    float8 in_row7 = winograd_input_4x4_5x5_nhwc_load_row(src_addr, y_coord, z_base + 7, src_stride_y, src_stride_z);

    // Calculate common factors for intermediate tensor
    // (each tmpN below is the linear combination of the 8 input rows given by row N of the 8x8 input transform matrix)
    float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4;
    float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3;
    float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6;

    // Calculate intermediate tensor and reuse common factor vectors
    const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2;
    const float8 tmp1 = comm_fact0 + comm_fact1;
    const float8 tmp2 = comm_fact0 - comm_fact1;

    comm_fact0 = 2.5f * in_row3;
    comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5;

    const float8 tmp3 = comm_fact1 + comm_fact2;
    const float8 tmp4 = comm_fact2 - comm_fact1;

    comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5;
    comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6;

    const float8 tmp5 = comm_fact1 + comm_fact2;
    const float8 tmp6 = comm_fact2 - comm_fact1;
    const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5;

    // Calculate output rows (reuse comm_fact0 vector as scratch for the macro)
    float8 out0, out1, out2, out3, out4, out5, out6, out7;

    OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0);
    OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0);

    // Store values across the 64 channels: output row k goes to Z planes [8k, 8k + 7]
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y;

    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out0, 0, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out1, 8, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out2, 16, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out3, 24, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out4, 32, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out5, 40, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out6, 48, dst_stride_z);
    winograd_input_4x4_5x5_nhwc_store_row(dst_addr, out7, 56, dst_stride_z);
}
2753#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002754#endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
Gian Marco Iodice7e4b2392018-02-22 16:17:20 +00002755
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01002756#if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)
/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2/2x1 or 1x2, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
 *
 * The transform applies A^T = [[1, 1, 1, 0], [0, 1, -1, -1]] to the 4 (1-D case) or 4x4 (2-D case) values read from
 * consecutive Z planes of the source tensor.
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note If -DHAS_BIAS is passed at compile time, the bias tensor arguments are expected and bias[z] is added to every element of the output tile
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional, only with -DHAS_BIAS) Pointer to the biases tensor. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional, only with -DHAS_BIAS) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional, only with -DHAS_BIAS) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional, only with -DHAS_BIAS) The offset of the first element in the biases tensor
 */
__kernel void winograd_output_transform_2x2_3x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Each thread stores a 2x2/2x1 or 1x2 tile according to the filter size
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);

    // Load the values across the 16 or 4 channels to compose the 4x4 or 4x1 tile
    float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
    float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
    float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
    float d03 = *((__global float *)(src_addr + 3 * src_stride_z));

#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // Compute the 2x1 or 1x2 output tile
    // out00 = d00 + d01 + d02
    // out01 = d01 - d02 - d03

    float out00 = d00 + d01 + d02;
    float out01 = d01 - d02 - d03;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
    float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
    float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
    float d13 = *((__global float *)(src_addr + 7 * src_stride_z));

    float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
    float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
    float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
    float d23 = *((__global float *)(src_addr + 11 * src_stride_z));

    float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
    float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
    float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
    float d33 = *((__global float *)(src_addr + 15 * src_stride_z));

    // Compute the 2x2 output tile (terms shared between the four outputs are computed once)
    float k0 = d01 + d11 + d21;
    float k1 = d02 + d12 + d22;
    float k2 = d11 - d21 - d31;
    float k3 = d12 - d22 - d32;

    // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
    // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
    // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
    // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)

    float out00 = d10;
    float out01 = -d13;
    float out10 = d10;
    float out11 = -d13;

    out00 += d00 + d20 + k0 + k1;
    out01 += k0 - k1 - (d03 + d23);
    out10 += -d20 - d30 + k2 + k3;
    out11 += k2 - k3 + d23 + d33;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

    // Map the tile index (dimension 1 of src) to the top-left output coordinate, and dimension 0 to the output plane
    int y_in  = get_global_id(1);
    int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
    int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
    int z_out = get_global_id(0);

#if defined(HAS_BIAS)
    // Add bias
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    float b = *((__global float *)(vector_offset(&bias, z_out)));

    out00 += b;
    out01 += b;
#endif // defined(HAS_BIAS)

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;

    // Store the output tile: a column for the 1x2 case, otherwise the first row
#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
    *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#if defined(HAS_BIAS)
    // Add bias
    out10 += b;
    out11 += b;
#endif // defined(HAS_BIAS)

    // Second row of the 2x2 tile
    vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
Giorgio Arenadd038702018-04-16 11:20:11 +01002883
/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The bias arguments are only part of the signature when -DHAS_BIAS is passed at compile time
 *
 * Work-item mapping: get_global_id(0) selects the output Z plane (also used to index the bias),
 * get_global_id(1) selects the tile, which is split into (x, y) using NUM_TILES_X.
 * Element k of the transformed tile (row-major, 6 per row) is read from source Z plane k.
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias vector, indexed by the output Z coordinate. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_4x4_3x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Each thread stores a 4x4/4x1 or 1x4 tile
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);

    // Load the values across the channels to compose the 6x6 or 6x1 tile
    float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
    float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
    float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
    float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
    float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
    float d05 = *((__global float *)(src_addr + 5 * src_stride_z));

#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // 1D transform out = A^T * d, where the rows of A^T are
    // (1, 1, 1, 1, 1, 0), (0, 1, -1, 2, -2, 0), (0, 1, 1, 4, 4, 0) and (0, 1, -1, 8, -8, 1)
    float out00 = d00 + d01 + d02 + d03 + d04;
    float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04;
    float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04;
    float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
    float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
    float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
    float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
    float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
    float d15 = *((__global float *)(src_addr + 11 * src_stride_z));

    float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
    float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
    float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
    float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
    float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
    float d25 = *((__global float *)(src_addr + 17 * src_stride_z));

    float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
    float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
    float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
    float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
    float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
    float d35 = *((__global float *)(src_addr + 23 * src_stride_z));

    float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
    float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
    float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
    float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
    float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
    float d45 = *((__global float *)(src_addr + 29 * src_stride_z));

    float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
    float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
    float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
    float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
    float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
    float d55 = *((__global float *)(src_addr + 35 * src_stride_z));

    // 2D transform out = A^T * D * A. Column 1 of D has weight 1 for every output column,
    // so its (row-weighted) contribution is computed once per output row and shared.

    // Compute out00, out01, out02 and out03
    const float col1_r0 = d01 + d21 + d41 + d11 + d31;
    float       out00   = col1_r0;
    float       out01   = col1_r0;
    float       out02   = col1_r0;
    float       out03   = col1_r0;

    float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
    float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;

    out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
    out01 += k1 - d02 - d12 - d22 - d32 - d42;
    out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
    out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;

    // Compute out10, out11, out12 and out13
    const float col1_r1 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
    float       out10   = col1_r1;
    float       out11   = col1_r1;
    float       out12   = col1_r1;
    float       out13   = col1_r1;

    k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;

    out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
    out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
    out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
    out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;

    // Compute out20, out21, out22 and out23
    const float col1_r2 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
    float       out20   = col1_r2;
    float       out21   = col1_r2;
    float       out22   = col1_r2;
    float       out23   = col1_r2;

    k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
    k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;

    out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
    out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
    out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
    out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;

    // Compute out30, out31, out32 and out33
    const float col1_r3 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
    float       out30   = col1_r3;
    float       out31   = col1_r3;
    float       out32   = col1_r3;
    float       out33   = col1_r3;

    k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
    k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;

    out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
    out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
    out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
    out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

    // Map the tile index to the top-left corner of the output tile
    int y_in  = get_global_id(1);
    int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
    int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
    int z_out = get_global_id(0);

#if defined(HAS_BIAS)
    // Add bias (one value per output Z plane)
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    float b = *((__global float *)(vector_offset(&bias, z_out)));

    out00 += b;
    out01 += b;
    out02 += b;
    out03 += b;
#endif // defined(HAS_BIAS)

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;

    // Store the output tile
#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // 1x4 tile: the four results go down one column
    *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
    *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
    *((__global float *)(dst_addr + 2 * dst_stride_y)) = out02;
    *((__global float *)(dst_addr + 3 * dst_stride_y)) = out03;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#if defined(HAS_BIAS)
    // Add bias to the remaining three rows
    out10 += b;
    out11 += b;
    out12 += b;
    out13 += b;

    out20 += b;
    out21 += b;
    out22 += b;
    out23 += b;

    out30 += b;
    out31 += b;
    out32 += b;
    out33 += b;
#endif // defined(HAS_BIAS)
    vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
    vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y));
    vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
3083
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01003084#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd output transform when the output tile is 2x1, the filter size 3x1 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The bias arguments are only part of the signature when -DHAS_BIAS is passed at compile time
 *
 * All arguments are forwarded unchanged to winograd_output_transform_2x2_3x3_nchw; the 3x1 behaviour
 * is selected there by the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL macro.
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_2x1_3x1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Delegate to the generic 2x2 kernel; the horizontal (1D) path is chosen at compile time
    winograd_output_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes
#if defined(HAS_BIAS)
                                           ,
                                           bias_ptr,
                                           bias_stride_x,
                                           bias_step_x,
                                           bias_offset_first_element_in_bytes
#endif // defined(HAS_BIAS)
                                          );
}
3143
/** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 3x1 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note The bias arguments are only part of the signature when -DHAS_BIAS is passed at compile time
 *
 * All arguments are forwarded unchanged to winograd_output_transform_4x4_3x3_nchw; the 3x1 behaviour
 * is selected there by the WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL macro.
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_4x1_3x1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Delegate to the generic 4x4 kernel; the horizontal (1D) path is chosen at compile time
    winograd_output_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes
#if defined(HAS_BIAS)
                                           ,
                                           bias_ptr,
                                           bias_stride_x,
                                           bias_step_x,
                                           bias_offset_first_element_in_bytes
#endif // defined(HAS_BIAS)
                                          );
}
3202#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL)
3203
3204#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
/** This OpenCL kernel performs Winograd output transform when the output tile is 1x2, the filter size 1x3 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The bias arguments are only part of the signature when -DHAS_BIAS is passed at compile time
 *
 * All arguments are forwarded unchanged to winograd_output_transform_2x2_3x3_nchw; the 1x3 behaviour
 * is selected there by the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL macro.
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_1x2_1x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Delegate to the generic 2x2 kernel; the vertical (1D) path is chosen at compile time
    winograd_output_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes
#if defined(HAS_BIAS)
                                           ,
                                           bias_ptr,
                                           bias_stride_x,
                                           bias_step_x,
                                           bias_offset_first_element_in_bytes
#endif // defined(HAS_BIAS)
                                          );
}
3263
/** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note The bias arguments are only part of the signature when -DHAS_BIAS is passed at compile time
 *
 * All arguments are forwarded unchanged to winograd_output_transform_4x4_3x3_nchw; the 1x3 behaviour
 * is selected there by the WINOGRAD_OUTPUT_TRANSFORM_VERTICAL macro.
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_1x4_1x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Delegate to the generic 4x4 kernel; the vertical (1D) path is chosen at compile time
    winograd_output_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes
#if defined(HAS_BIAS)
                                           ,
                                           bias_ptr,
                                           bias_stride_x,
                                           bias_step_x,
                                           bias_offset_first_element_in_bytes
#endif // defined(HAS_BIAS)
                                          );
}
3322#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
3323
Giorgio Arenac42f28d2018-04-26 11:33:05 +01003324/** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003325 *
3326 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
3327 *
3328 * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32
3329 * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes)
3330 * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
3331 * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes)
3332 * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
3333 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3334 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3335 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor
3336 * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr
3337 * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes)
3338 * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
3339 * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes)
3340 * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes)
3341 * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes)
3342 * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes)
3343 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003344 * @param[in] dst_size Size of the destination tensor, minus the last padding
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003345 */
3346__kernel void winograd_output_transform_4x4_3x3_nhwc(
3347 TENSOR3D_DECLARATION(src),
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003348 TENSOR3D_DECLARATION(dst),
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003349#if defined(HAS_BIAS)
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003350 VECTOR_DECLARATION(bias),
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003351#endif // defined(HAS_BIAS)
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003352 int dst_size)
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003353{
3354 // Each thread stores a 4x4 tile
3355 Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);
3356
3357 const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);
3358
3359 // Load the values across the 36 channels to compose the 6x6 tile
3360 float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
3361 float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
3362 float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
3363 float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
3364 float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
3365 float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
3366
3367 float d10 = *((__global float *)(src_addr + 6 * src_stride_z));
3368 float d11 = *((__global float *)(src_addr + 7 * src_stride_z));
3369 float d12 = *((__global float *)(src_addr + 8 * src_stride_z));
3370 float d13 = *((__global float *)(src_addr + 9 * src_stride_z));
3371 float d14 = *((__global float *)(src_addr + 10 * src_stride_z));
3372 float d15 = *((__global float *)(src_addr + 11 * src_stride_z));
3373
3374 float d20 = *((__global float *)(src_addr + 12 * src_stride_z));
3375 float d21 = *((__global float *)(src_addr + 13 * src_stride_z));
3376 float d22 = *((__global float *)(src_addr + 14 * src_stride_z));
3377 float d23 = *((__global float *)(src_addr + 15 * src_stride_z));
3378 float d24 = *((__global float *)(src_addr + 16 * src_stride_z));
3379 float d25 = *((__global float *)(src_addr + 17 * src_stride_z));
3380
3381 float d30 = *((__global float *)(src_addr + 18 * src_stride_z));
3382 float d31 = *((__global float *)(src_addr + 19 * src_stride_z));
3383 float d32 = *((__global float *)(src_addr + 20 * src_stride_z));
3384 float d33 = *((__global float *)(src_addr + 21 * src_stride_z));
3385 float d34 = *((__global float *)(src_addr + 22 * src_stride_z));
3386 float d35 = *((__global float *)(src_addr + 23 * src_stride_z));
3387
3388 float d40 = *((__global float *)(src_addr + 24 * src_stride_z));
3389 float d41 = *((__global float *)(src_addr + 25 * src_stride_z));
3390 float d42 = *((__global float *)(src_addr + 26 * src_stride_z));
3391 float d43 = *((__global float *)(src_addr + 27 * src_stride_z));
3392 float d44 = *((__global float *)(src_addr + 28 * src_stride_z));
3393 float d45 = *((__global float *)(src_addr + 29 * src_stride_z));
3394
3395 float d50 = *((__global float *)(src_addr + 30 * src_stride_z));
3396 float d51 = *((__global float *)(src_addr + 31 * src_stride_z));
3397 float d52 = *((__global float *)(src_addr + 32 * src_stride_z));
3398 float d53 = *((__global float *)(src_addr + 33 * src_stride_z));
3399 float d54 = *((__global float *)(src_addr + 34 * src_stride_z));
3400 float d55 = *((__global float *)(src_addr + 35 * src_stride_z));
3401
3402 // Compute out00, out01, out02 and out03
3403 float out00 = d01 + d21 + d41 + d11 + d31;
3404 float out01 = d01 + d21 + d41 + d11 + d31;
3405 float out02 = d01 + d21 + d41 + d11 + d31;
3406 float out03 = d01 + d21 + d41 + d11 + d31;
3407
3408 float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44;
3409 float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44;
3410
3411 out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42;
3412 out01 += k1 - d02 - d12 - d22 - d32 - d42;
3413 out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42;
3414 out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45;
3415
3416 // Compute out10, out11, out12 and out13
3417 float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
3418 float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
3419 float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
3420 float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41;
3421
3422 k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44;
3423 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44;
3424
3425 out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42;
3426 out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42;
3427 out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42;
3428 out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45;
3429
3430 // Compute out20, out21, out22 and out23
3431 float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
3432 float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
3433 float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
3434 float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41;
3435
3436 k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44;
3437 k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44;
3438
3439 out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42;
3440 out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42;
3441 out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42;
3442 out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45;
3443
3444 // Compute out30, out31, out32 and out33
3445 float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
3446 float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
3447 float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
3448 float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51;
3449
3450 k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54;
3451 k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54;
3452
3453 out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52;
3454 out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52;
3455 out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52;
3456 out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55;
3457
3458 int y_in = get_global_id(1);
3459 int x_out = get_global_id(0);
3460 int y_out = (y_in % NUM_TILES_X) * 4;
3461 int z_out = (y_in / NUM_TILES_X) * 4;
3462
3463#if defined(HAS_BIAS)
3464 // Add bias
3465 Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);
3466
3467 float b = (float) * ((__global float *)(vector_offset(&bias, z_out)));
3468
3469 out00 += (float)b;
3470 out01 += (float)b;
3471 out02 += (float)b;
3472 out03 += (float)b;
3473
3474 out10 += (float)b;
3475 out11 += (float)b;
3476 out12 += (float)b;
3477 out13 += (float)b;
3478
3479 out20 += (float)b;
3480 out21 += (float)b;
3481 out22 += (float)b;
3482 out23 += (float)b;
3483
3484 out30 += (float)b;
3485 out31 += (float)b;
3486 out32 += (float)b;
3487 out33 += (float)b;
3488
3489#endif // defined(HAS_BIAS)
3490
3491 // Get output address
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003492 int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
3493 offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
3494 int4 mult_y = min(dst_size - offset, 1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003495
3496 // Store the 4x4 output tile
Georgios Pinitasc084f0d2018-06-11 17:43:31 +01003497 *((__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = out00;
3498 *((__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = out01;
3499 *((__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = out02;
3500 *((__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = out03;
3501 *((__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = out10;
3502 *((__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = out11;
3503 *((__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = out12;
3504 *((__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = out13;
3505 *((__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = out20;
3506 *((__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = out21;
3507 *((__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = out22;
3508 *((__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = out23;
3509 *((__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = out30;
3510 *((__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = out31;
3511 *((__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = out32;
3512 *((__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = out33;
Giorgio Arena3695f9a2018-04-23 17:41:22 +01003513}
3514
/** Applies the 8-point to 4-point output transform to one column (or row) of an 8x8 tile.
 *
 * Given the eight inputs d0..d7 it writes:
 *   col.s0 = (d1 + d2) +     (d3 + d4) + 8 * (d5 + d6) + d0
 *   col.s1 = (d1 - d2) + 2 * (d3 - d4) + 4 * (d5 - d6)
 *   col.s2 = (d1 + d2) + 4 * (d3 + d4) + 2 * (d5 + d6)
 *   col.s3 = (d1 - d2) + 8 * (d3 - d4) +     (d5 - d6) + d7
 *
 * @param col       Destination vector lvalue; components s0..s3 are overwritten
 * @param d0        First input value (evaluated once)
 * @param d1        Input value (evaluated twice: pass a side-effect-free expression)
 * @param d2        Input value (evaluated twice)
 * @param d3        Input value (evaluated twice)
 * @param d4        Input value (evaluated twice)
 * @param d5        Input value (evaluated twice)
 * @param d6        Input value (evaluated twice)
 * @param d7        Last input value (evaluated once)
 * @param comm_fact Scratch vector lvalue; components s0..s2 are clobbered
 *
 * Every argument is parenthesized in the expansion so compound expressions (e.g. `a + b`)
 * are treated as single operands, and the body is wrapped in do { } while(0) so the macro
 * behaves as one statement without relying on compiler statement-expression extensions.
 */
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)           \
    do                                                                            \
    {                                                                             \
        (comm_fact).s0 = (d1) + (d2);                                             \
        (comm_fact).s1 = (d3) + (d4);                                             \
        (comm_fact).s2 = (d5) + (d6);                                             \
                                                                                  \
        (col).s0 = (comm_fact).s0 + (comm_fact).s1 + 8.f * (comm_fact).s2 + (d0); \
        (col).s2 = (comm_fact).s0 + 4.f * (comm_fact).s1 + 2.f * (comm_fact).s2;  \
                                                                                  \
        (comm_fact).s0 = (d1) - (d2);                                             \
        (comm_fact).s1 = (d3) - (d4);                                             \
        (comm_fact).s2 = (d5) - (d6);                                             \
                                                                                  \
        (col).s1 = (comm_fact).s0 + 2.f * (comm_fact).s1 + 4.f * (comm_fact).s2;  \
        (col).s3 = (comm_fact).s0 + 8.f * (comm_fact).s1 + (comm_fact).s2 + (d7); \
    } while(0)
3531
/** Winograd output transform for the NCHW data layout with a 4x4 output tile and a 5x5 filter.
 *
 * One work-item handles one tile. At the current source position (as set up by
 * CONVERT_TO_TENSOR3D_STRUCT) it gathers the 8x8 Winograd-domain tile, whose element (r, c)
 * is stored in Z plane 8 * r + c. The tile is reduced to 4x4 with COMPUTE_TMP_COL (first down
 * each of the 8 columns, then across the 8 intermediate columns). The bias is optionally added,
 * and the 4x4 block is written to @p dst. With tile = get_global_id(1), the block starts at
 * X = 4 * (tile % NUM_TILES_X), Y = 4 * (tile / NUM_TILES_X) and Z = get_global_id(0).
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note If -DHAS_BIAS is passed at compile time, the bias value of the destination Z plane is added to the whole tile
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional, only with -DHAS_BIAS) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 */
__kernel void winograd_output_transform_4x4_5x5_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);

    // Reads the F32 value held in Z plane 'plane' at the current source position
#define W5_NCHW_LOAD(plane) (*((__global float *)(src_addr + (plane) * src_stride_z)))

    // Gather the 8x8 Winograd-domain tile row by row: rowR.sC is the element stored in plane 8 * R + C
    const float8 row0 = (float8)(W5_NCHW_LOAD(0), W5_NCHW_LOAD(1), W5_NCHW_LOAD(2), W5_NCHW_LOAD(3),
                                 W5_NCHW_LOAD(4), W5_NCHW_LOAD(5), W5_NCHW_LOAD(6), W5_NCHW_LOAD(7));
    const float8 row1 = (float8)(W5_NCHW_LOAD(8), W5_NCHW_LOAD(9), W5_NCHW_LOAD(10), W5_NCHW_LOAD(11),
                                 W5_NCHW_LOAD(12), W5_NCHW_LOAD(13), W5_NCHW_LOAD(14), W5_NCHW_LOAD(15));
    const float8 row2 = (float8)(W5_NCHW_LOAD(16), W5_NCHW_LOAD(17), W5_NCHW_LOAD(18), W5_NCHW_LOAD(19),
                                 W5_NCHW_LOAD(20), W5_NCHW_LOAD(21), W5_NCHW_LOAD(22), W5_NCHW_LOAD(23));
    const float8 row3 = (float8)(W5_NCHW_LOAD(24), W5_NCHW_LOAD(25), W5_NCHW_LOAD(26), W5_NCHW_LOAD(27),
                                 W5_NCHW_LOAD(28), W5_NCHW_LOAD(29), W5_NCHW_LOAD(30), W5_NCHW_LOAD(31));
    const float8 row4 = (float8)(W5_NCHW_LOAD(32), W5_NCHW_LOAD(33), W5_NCHW_LOAD(34), W5_NCHW_LOAD(35),
                                 W5_NCHW_LOAD(36), W5_NCHW_LOAD(37), W5_NCHW_LOAD(38), W5_NCHW_LOAD(39));
    const float8 row5 = (float8)(W5_NCHW_LOAD(40), W5_NCHW_LOAD(41), W5_NCHW_LOAD(42), W5_NCHW_LOAD(43),
                                 W5_NCHW_LOAD(44), W5_NCHW_LOAD(45), W5_NCHW_LOAD(46), W5_NCHW_LOAD(47));
    const float8 row6 = (float8)(W5_NCHW_LOAD(48), W5_NCHW_LOAD(49), W5_NCHW_LOAD(50), W5_NCHW_LOAD(51),
                                 W5_NCHW_LOAD(52), W5_NCHW_LOAD(53), W5_NCHW_LOAD(54), W5_NCHW_LOAD(55));
    const float8 row7 = (float8)(W5_NCHW_LOAD(56), W5_NCHW_LOAD(57), W5_NCHW_LOAD(58), W5_NCHW_LOAD(59),
                                 W5_NCHW_LOAD(60), W5_NCHW_LOAD(61), W5_NCHW_LOAD(62), W5_NCHW_LOAD(63));

#undef W5_NCHW_LOAD

    // First pass: collapse each column of 8 values into 4 (tC.sR = intermediate row R, column C)
    float4 scratch;
    float4 t0, t1, t2, t3, t4, t5, t6, t7;

    COMPUTE_TMP_COL(t0, row0.s0, row1.s0, row2.s0, row3.s0, row4.s0, row5.s0, row6.s0, row7.s0, scratch);
    COMPUTE_TMP_COL(t1, row0.s1, row1.s1, row2.s1, row3.s1, row4.s1, row5.s1, row6.s1, row7.s1, scratch);
    COMPUTE_TMP_COL(t2, row0.s2, row1.s2, row2.s2, row3.s2, row4.s2, row5.s2, row6.s2, row7.s2, scratch);
    COMPUTE_TMP_COL(t3, row0.s3, row1.s3, row2.s3, row3.s3, row4.s3, row5.s3, row6.s3, row7.s3, scratch);
    COMPUTE_TMP_COL(t4, row0.s4, row1.s4, row2.s4, row3.s4, row4.s4, row5.s4, row6.s4, row7.s4, scratch);
    COMPUTE_TMP_COL(t5, row0.s5, row1.s5, row2.s5, row3.s5, row4.s5, row5.s5, row6.s5, row7.s5, scratch);
    COMPUTE_TMP_COL(t6, row0.s6, row1.s6, row2.s6, row3.s6, row4.s6, row5.s6, row6.s6, row7.s6, scratch);
    COMPUTE_TMP_COL(t7, row0.s7, row1.s7, row2.s7, row3.s7, row4.s7, row5.s7, row6.s7, row7.s7, scratch);

    // Second pass: combine the 8 intermediate columns into the 4 output columns (out_colJ.sR = output row R, column J)
    const float4 sum12 = t1 + t2;
    const float4 sum34 = t3 + t4;
    const float4 sum56 = t5 + t6;
    const float4 dif12 = t1 - t2;
    const float4 dif34 = t3 - t4;
    const float4 dif56 = t5 - t6;

    float4 out_col0 = sum12 + sum34 + 8.f * sum56 + t0;
    float4 out_col1 = dif12 + 2.f * dif34 + 4.f * dif56;
    float4 out_col2 = sum12 + 4.f * sum34 + 2.f * sum56;
    float4 out_col3 = dif12 + 8.f * dif34 + dif56 + t7;

    // Locate the 4x4 block in the destination
    const int tile  = get_global_id(1);
    const int x_out = (tile % NUM_TILES_X) * 4;
    const int y_out = (tile / NUM_TILES_X) * 4;
    const int z_out = get_global_id(0);

#if defined(HAS_BIAS)
    // One bias value (the one of destination plane z_out) is shared by the whole tile
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    const float4 bias_vec = (float4)(*((__global float *)(vector_offset(&bias, z_out))));

    out_col0 += bias_vec;
    out_col1 += bias_vec;
    out_col2 += bias_vec;
    out_col3 += bias_vec;
#endif // defined(HAS_BIAS)

    // Output row R is made of component R of out_col0..out_col3; step one Y stride per row
    __global uchar *dst_row = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z;

    *((__global float *)(dst_row + 0 * dst_stride_x)) = out_col0.s0;
    *((__global float *)(dst_row + 1 * dst_stride_x)) = out_col1.s0;
    *((__global float *)(dst_row + 2 * dst_stride_x)) = out_col2.s0;
    *((__global float *)(dst_row + 3 * dst_stride_x)) = out_col3.s0;

    dst_row += dst_stride_y;
    *((__global float *)(dst_row + 0 * dst_stride_x)) = out_col0.s1;
    *((__global float *)(dst_row + 1 * dst_stride_x)) = out_col1.s1;
    *((__global float *)(dst_row + 2 * dst_stride_x)) = out_col2.s1;
    *((__global float *)(dst_row + 3 * dst_stride_x)) = out_col3.s1;

    dst_row += dst_stride_y;
    *((__global float *)(dst_row + 0 * dst_stride_x)) = out_col0.s2;
    *((__global float *)(dst_row + 1 * dst_stride_x)) = out_col1.s2;
    *((__global float *)(dst_row + 2 * dst_stride_x)) = out_col2.s2;
    *((__global float *)(dst_row + 3 * dst_stride_x)) = out_col3.s2;

    dst_row += dst_stride_y;
    *((__global float *)(dst_row + 0 * dst_stride_x)) = out_col0.s3;
    *((__global float *)(dst_row + 1 * dst_stride_x)) = out_col1.s3;
    *((__global float *)(dst_row + 2 * dst_stride_x)) = out_col2.s3;
    *((__global float *)(dst_row + 3 * dst_stride_x)) = out_col3.s3;
}
Giorgio Arena7210fe82018-06-08 12:24:14 +01003706
/** Winograd output transform for the NHWC data layout with a 4x4 output tile and a 5x5 filter.
 *
 * One work-item handles one tile. At the current source position (as set up by
 * CONVERT_TO_TENSOR3D_STRUCT) it gathers the 8x8 Winograd-domain tile, whose element (r, c)
 * is stored in Z plane 8 * r + c. The tile is reduced to 4x4 with COMPUTE_TMP_COL, the bias is
 * optionally added, and the block is written to @p dst.
 *
 * In the destination, X is the innermost dimension (addressed as x_out * sizeof(float)), i.e. the
 * channel. Y and Z are the two spatial dimensions. With tile = get_global_id(1), the block covers
 * channel X = get_global_id(0), Y = 4 * (tile % NUM_TILES_X) .. +3 and Z = 4 * (tile / NUM_TILES_X) .. +3.
 * Output planes whose start address reaches @p dst_size are redirected to @p dst_size (the trailing
 * padding). All four writes of such a plane then collapse onto that single address.
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note If -DHAS_BIAS is passed at compile time, the bias of the output channel (X) is added to the whole tile
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional, only with -DHAS_BIAS) Pointer to the bias vector. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional) Stride of the bias vector in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional) The offset of the first element in the bias vector
 * @param[in]  dst_size                           Byte offset from @p dst_ptr of the trailing padding; output planes starting at or beyond it are redirected there
 */
__kernel void winograd_output_transform_4x4_5x5_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst),
#if defined(HAS_BIAS)
    VECTOR_DECLARATION(bias),
#endif // defined(HAS_BIAS)
    int dst_size)
{
    // Each thread stores a 4x4 tile
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);

    // Load the values across the 64 channels to compose the 8x8 input tile: dRC is stored in plane 8 * R + C
    float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
    float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
    float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
    float d03 = *((__global float *)(src_addr + 3 * src_stride_z));
    float d04 = *((__global float *)(src_addr + 4 * src_stride_z));
    float d05 = *((__global float *)(src_addr + 5 * src_stride_z));
    float d06 = *((__global float *)(src_addr + 6 * src_stride_z));
    float d07 = *((__global float *)(src_addr + 7 * src_stride_z));

    float d10 = *((__global float *)(src_addr + 8 * src_stride_z));
    float d11 = *((__global float *)(src_addr + 9 * src_stride_z));
    float d12 = *((__global float *)(src_addr + 10 * src_stride_z));
    float d13 = *((__global float *)(src_addr + 11 * src_stride_z));
    float d14 = *((__global float *)(src_addr + 12 * src_stride_z));
    float d15 = *((__global float *)(src_addr + 13 * src_stride_z));
    float d16 = *((__global float *)(src_addr + 14 * src_stride_z));
    float d17 = *((__global float *)(src_addr + 15 * src_stride_z));

    float d20 = *((__global float *)(src_addr + 16 * src_stride_z));
    float d21 = *((__global float *)(src_addr + 17 * src_stride_z));
    float d22 = *((__global float *)(src_addr + 18 * src_stride_z));
    float d23 = *((__global float *)(src_addr + 19 * src_stride_z));
    float d24 = *((__global float *)(src_addr + 20 * src_stride_z));
    float d25 = *((__global float *)(src_addr + 21 * src_stride_z));
    float d26 = *((__global float *)(src_addr + 22 * src_stride_z));
    float d27 = *((__global float *)(src_addr + 23 * src_stride_z));

    float d30 = *((__global float *)(src_addr + 24 * src_stride_z));
    float d31 = *((__global float *)(src_addr + 25 * src_stride_z));
    float d32 = *((__global float *)(src_addr + 26 * src_stride_z));
    float d33 = *((__global float *)(src_addr + 27 * src_stride_z));
    float d34 = *((__global float *)(src_addr + 28 * src_stride_z));
    float d35 = *((__global float *)(src_addr + 29 * src_stride_z));
    float d36 = *((__global float *)(src_addr + 30 * src_stride_z));
    float d37 = *((__global float *)(src_addr + 31 * src_stride_z));

    float d40 = *((__global float *)(src_addr + 32 * src_stride_z));
    float d41 = *((__global float *)(src_addr + 33 * src_stride_z));
    float d42 = *((__global float *)(src_addr + 34 * src_stride_z));
    float d43 = *((__global float *)(src_addr + 35 * src_stride_z));
    float d44 = *((__global float *)(src_addr + 36 * src_stride_z));
    float d45 = *((__global float *)(src_addr + 37 * src_stride_z));
    float d46 = *((__global float *)(src_addr + 38 * src_stride_z));
    float d47 = *((__global float *)(src_addr + 39 * src_stride_z));

    float d50 = *((__global float *)(src_addr + 40 * src_stride_z));
    float d51 = *((__global float *)(src_addr + 41 * src_stride_z));
    float d52 = *((__global float *)(src_addr + 42 * src_stride_z));
    float d53 = *((__global float *)(src_addr + 43 * src_stride_z));
    float d54 = *((__global float *)(src_addr + 44 * src_stride_z));
    float d55 = *((__global float *)(src_addr + 45 * src_stride_z));
    float d56 = *((__global float *)(src_addr + 46 * src_stride_z));
    float d57 = *((__global float *)(src_addr + 47 * src_stride_z));

    float d60 = *((__global float *)(src_addr + 48 * src_stride_z));
    float d61 = *((__global float *)(src_addr + 49 * src_stride_z));
    float d62 = *((__global float *)(src_addr + 50 * src_stride_z));
    float d63 = *((__global float *)(src_addr + 51 * src_stride_z));
    float d64 = *((__global float *)(src_addr + 52 * src_stride_z));
    float d65 = *((__global float *)(src_addr + 53 * src_stride_z));
    float d66 = *((__global float *)(src_addr + 54 * src_stride_z));
    float d67 = *((__global float *)(src_addr + 55 * src_stride_z));

    float d70 = *((__global float *)(src_addr + 56 * src_stride_z));
    float d71 = *((__global float *)(src_addr + 57 * src_stride_z));
    float d72 = *((__global float *)(src_addr + 58 * src_stride_z));
    float d73 = *((__global float *)(src_addr + 59 * src_stride_z));
    float d74 = *((__global float *)(src_addr + 60 * src_stride_z));
    float d75 = *((__global float *)(src_addr + 61 * src_stride_z));
    float d76 = *((__global float *)(src_addr + 62 * src_stride_z));
    float d77 = *((__global float *)(src_addr + 63 * src_stride_z));

    // Compute the 8x4 intermediate tensor: each column of 8 values collapses to 4
    float4 comm_fact0, comm_fact1, comm_fact2;
    float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7;

    COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0);
    COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0);
    COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0);
    COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0);
    COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0);
    COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0);
    COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0);
    COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0);

    // Compute the 4x4 output tile: out_colJ.sR is output row R, column J
    comm_fact0 = tmp_col1 + tmp_col2;
    comm_fact1 = tmp_col3 + tmp_col4;
    comm_fact2 = tmp_col5 + tmp_col6;

    float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0;
    float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2;

    comm_fact0 = tmp_col1 - tmp_col2;
    comm_fact1 = tmp_col3 - tmp_col4;
    comm_fact2 = tmp_col5 - tmp_col6;

    float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2;
    float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7;

    // Tile coordinates: X is the output channel, Y/Z the spatial start of the 4x4 block
    int tile_id = get_global_id(1);
    int x_out   = get_global_id(0);
    int y_out   = (tile_id % NUM_TILES_X) * 4;
    int z_out   = (tile_id / NUM_TILES_X) * 4;

#if defined(HAS_BIAS)
    // Add bias: there is one bias value per output channel, and in NHWC the channel is X (x_out),
    // not the spatial coordinate z_out
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    float b = *((__global float *)(vector_offset(&bias, x_out)));

    out_col0 += (float4)b;
    out_col1 += (float4)b;
    out_col2 += (float4)b;
    out_col3 += (float4)b;
#endif // defined(HAS_BIAS)

    // Get output address: offset.sR is the start of output plane z_out + R
    int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z);
    offset      = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding).
    int4 mult_y = min(dst_size - offset, 1);                                        // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise.

    // Store the 4x4 output tile. Plane R must use its own multiplier mult_y.sR, so that a clamped
    // plane keeps all of its writes on dst_size instead of stepping past it along Y.
    *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0) = out_col0.s0;
    *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0) = out_col1.s0;
    *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0) = out_col2.s0;
    *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0) = out_col3.s0;
    *(__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1) = out_col0.s1;
    *(__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1) = out_col1.s1;
    *(__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1) = out_col2.s1;
    *(__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1) = out_col3.s1;
    *(__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2) = out_col0.s2;
    *(__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2) = out_col1.s2;
    *(__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2) = out_col2.s2;
    *(__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2) = out_col3.s2;
    *(__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3) = out_col0.s3;
    *(__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3) = out_col1.s3;
    *(__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3) = out_col2.s3;
    *(__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3) = out_col3.s3;
}
Gian Marco Iodicef1c2bf02018-06-13 14:05:54 +01003882#endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H)