Giorgio Arena | 1f9ca1d | 2018-03-01 11:13:45 +0000 | [diff] [blame] | 1 | /* |
| 2 | * Copyright (c) 2018 ARM Limited. |
| 3 | * |
| 4 | * SPDX-License-Identifier: MIT |
| 5 | * |
| 6 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | * of this software and associated documentation files (the "Software"), to |
| 8 | * deal in the Software without restriction, including without limitation the |
| 9 | * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or |
| 10 | * sell copies of the Software, and to permit persons to whom the Software is |
| 11 | * furnished to do so, subject to the following conditions: |
| 12 | * |
| 13 | * The above copyright notice and this permission notice shall be included in all |
| 14 | * copies or substantial portions of the Software. |
| 15 | * |
| 16 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | * SOFTWARE. |
| 23 | */ |
| 24 | #include "helpers.h" |
| 25 | |
Giorgio Arena | dcb5b28 | 2018-04-25 12:07:29 +0100 | [diff] [blame] | 26 | #if defined(SRC_DIM_Z) |
Giorgio Arena | 1f9ca1d | 2018-03-01 11:13:45 +0000 | [diff] [blame] | 27 | |
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 2x2/2x1/1x2
 *
 * The transform computed is U = G * w * G^T for a 3x3 filter, or U = G * w for a 3x1/1x3 filter, with
 *
 *     G = | 1     0     0   |
 *         | 1/2   1/2   1/2 |
 *         | 1/2  -1/2   1/2 |
 *         | 0     0     1   |
 *
 * Each work-item transforms one filter slice. get_global_id(2) is split into the filter index (z / SRC_DIM_Z),
 * which selects the destination X coordinate, and the channel index (z % SRC_DIM_Z), which selects the destination
 * Y coordinate. The transformed coefficients are stored one per destination Z plane: 16 planes for 3x3, 4 planes
 * for 3x1 or 1x3.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x2_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: the three taps are contiguous along X
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three taps lie along Y, so gather them one src_stride_y apart
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0. Since G row 0 is [1, 0, 0], this is also the complete result G * w in the 3x1/1x3 cases
    float4 out0 = 0.0f;
    out0.s0     = (w0.s0);
    out0.s1     = (w0.s0 + w0.s1 + w0.s2) * 0.5f;
    out0.s2     = (w0.s0 + w0.s2 - w0.s1) * 0.5f;
    out0.s3     = (w0.s2);

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float4 out1 = 0.0f;
    out1.s0     = (w0.s0 + w1.s0 + w2.s0) * 0.5f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) * 0.25f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 + w0.s2 + w1.s2 + w2.s2 - w0.s1 - w1.s1 - w2.s1) * 0.25f;
    out1.s3     = (w0.s2 + w1.s2 + w2.s2) * 0.5f;

    // Row 2
    float4 out2 = 0.0f;
    out2.s0     = (w0.s0 + w2.s0 - w1.s0) * 0.5f;
    out2.s1     = (w0.s0 + w2.s0 + w0.s1 + w2.s1 + w0.s2 + w2.s2 - w1.s0 - w1.s1 - w1.s2) * 0.25f;
    out2.s2     = (w0.s0 + w2.s0 + w1.s1 + w0.s2 + w2.s2 - w1.s0 - w0.s1 - w2.s1 - w1.s2) * 0.25f;
    out2.s3     = (w0.s2 + w2.s2 - w1.s2) * 0.5f;

    // Row 3
    float4 out3 = 0.0f;
    out3.s0     = (w2.s0);
    out3.s1     = (w2.s0 + w2.s1 + w2.s2) * 0.5f;
    out3.s2     = (w2.s0 + w2.s2 - w2.s1) * 0.5f;
    out3.s3     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels
    // 16 channels for 3x3 kernels
    // 4 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out2.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out2.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out3.s3;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena | 2d9de0a | 2018-03-15 17:58:20 +0000 | [diff] [blame] | 134 | |
/** This OpenCL kernel performs Winograd filter transform 3x3/3x1/1x3 when the data layout is NCHW and the output tile is 4x4/4x1/1x4
 *
 * For a 3x3 filter the transform computed is U = G * w * G^T, with
 *
 *     G = |  1/4    0      0   |
 *         | -1/6   -1/6   -1/6 |
 *         | -1/6    1/6   -1/6 |
 *         |  1/24   1/12   1/6 |
 *         |  1/24  -1/12   1/6 |
 *         |  0      0      1   |
 *
 * For a 3x1/1x3 filter only Row 0 of that product is produced (see the NOTE(review) in the body).
 *
 * Each work-item transforms one filter slice. get_global_id(2) is split into the filter index (z / SRC_DIM_Z),
 * which selects the destination X coordinate, and the channel index (z % SRC_DIM_Z), which selects the destination
 * Y coordinate. The transformed coefficients are stored one per destination Z plane: 36 planes for 3x3, 6 planes
 * for 3x1 or 1x3.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note If this kernel is used to perform Winograd filter transform 3x1, -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd filter transform 1x3, -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z);

    const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0);

    // Load the values from the input tensor
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
    // 3x1 filter: the three taps are contiguous along X
    float3 w0 = vload3(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 1x3 filter: the three taps lie along Y, so gather them one src_stride_y apart
    float3 w0 = (float3)(*((__global float *)(src_addr + 0 * src_stride_y)),
                         *((__global float *)(src_addr + 1 * src_stride_y)),
                         *((__global float *)(src_addr + 2 * src_stride_y)));
#else  // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // 3x3 filter: one float3 per filter row
    float3 w0 = vload3(0, (__global float *)(src_addr + 0 * src_stride_y));
    float3 w1 = vload3(0, (__global float *)(src_addr + 1 * src_stride_y));
    float3 w2 = vload3(0, (__global float *)(src_addr + 2 * src_stride_y));
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)

    // Row 0
    // NOTE(review): in the 3x1/1x3 configurations this Row 0 is the only output. Its coefficients carry the extra
    // factor 1/4 of G row 0 (compare out0.s5 = w0.s2 / 4 with out5.s5 = w2.s2), so the 1D result is (1/4) * G * w.
    // Presumably the matching 1D input/output transforms compensate for this factor -- confirm before changing.
    float8 out0 = 0.0f;
    out0.s0     = (w0.s0) / 16.f;
    out0.s1     = (-w0.s0 - w0.s1 - w0.s2) / 24.f;
    out0.s2     = (-w0.s0 + w0.s1 - w0.s2) / 24.f;
    out0.s3     = (w0.s0 + 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s4     = (w0.s0 - 2.f * w0.s1 + 4.f * w0.s2) / 96.f;
    out0.s5     = (w0.s2) / 4.f;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    // Row 1
    float8 out1 = 0.0f;
    out1.s0     = (-w0.s0 - w1.s0 - w2.s0) / 24.f;
    out1.s1     = (w0.s0 + w1.s0 + w2.s0 + w0.s1 + w1.s1 + w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s2     = (w0.s0 + w1.s0 + w2.s0 - w0.s1 - w1.s1 - w2.s1 + w0.s2 + w1.s2 + w2.s2) / 36.f;
    out1.s3     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (-w0.s1 - w1.s1 - w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s4     = (-w0.s0 - w1.s0 - w2.s0 + 2.f * (w0.s1 + w1.s1 + w2.s1) + 4.f * (-w0.s2 - w1.s2 - w2.s2)) / 144.f;
    out1.s5     = (-w0.s2 - w1.s2 - w2.s2) / 6.f;

    // Row 2
    float8 out2 = 0.0f;
    out2.s0     = (-w0.s0 + w1.s0 - w2.s0) / 24.f;
    out2.s1     = (w0.s0 - w1.s0 + w2.s0 + w0.s1 - w1.s1 + w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s2     = (w0.s0 - w1.s0 + w2.s0 - w0.s1 + w1.s1 - w2.s1 + w0.s2 - w1.s2 + w2.s2) / 36.f;
    out2.s3     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (-w0.s1 + w1.s1 - w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s4     = (-w0.s0 + w1.s0 - w2.s0 + 2.f * (w0.s1 - w1.s1 + w2.s1) + 4.f * (-w0.s2 + w1.s2 - w2.s2)) / 144.f;
    out2.s5     = (-w0.s2 + w1.s2 - w2.s2) / 6.f;

    // Row 3
    float8 out3 = 0.0f;
    out3.s0     = (w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out3.s1     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 - 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s2     = (-w0.s0 - 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 + 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 - 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out3.s3     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 + 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s4     = ((w0.s0 + 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 - 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out3.s5     = (w0.s2 + 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 4
    float8 out4 = 0.0f;
    out4.s0     = (w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) / 96.f;
    out4.s1     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 - w0.s1 + 2.f * w1.s1 - 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s2     = (-w0.s0 + 2.f * w1.s0 - 4.f * w2.s0 + w0.s1 - 2.f * w1.s1 + 4.f * w2.s1 - w0.s2 + 2.f * w1.s2 - 4.f * w2.s2) / 144.f;
    out4.s3     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (w0.s1 - 2.f * w1.s1 + 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s4     = ((w0.s0 - 2.f * w1.s0 + 4.f * w2.s0) + 2.f * (-w0.s1 + 2.f * w1.s1 - 4.f * w2.s1) + 4.f * (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2)) / 576.f;
    out4.s5     = (w0.s2 - 2.f * w1.s2 + 4.f * w2.s2) / 24.f;

    // Row 5
    float8 out5 = 0.0f;
    out5.s0     = (w2.s0) / 4.f;
    out5.s1     = (-w2.s0 - w2.s1 - w2.s2) / 6.f;
    out5.s2     = (-w2.s0 + w2.s1 - w2.s2) / 6.f;
    out5.s3     = (w2.s0 + 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s4     = (w2.s0 - 2.f * w2.s1 + 4.f * w2.s2) / 24.f;
    out5.s5     = (w2.s2);
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)

    int z  = get_global_id(2);
    int x0 = z / SRC_DIM_Z; // idx filter
    int y0 = z % SRC_DIM_Z; // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels
    // 36 channels for 3x3 kernels
    // 6 channels for 3x1 or 1x3 kernels
    *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5;

#if !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s2;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s3;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out5.s5;
#endif // !defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL)
}
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 287 | |
#if defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 2x1
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. Because
 * WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined, that kernel loads a single 3-tap row and writes 4 destination Z planes.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_2x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}

/** This OpenCL kernel performs Winograd filter transform 3x1 when the data layout is NCHW and the output tile is 4x1
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_4x4_3x3_nchw. Because
 * WINOGRAD_FILTER_TRANSFORM_HORIZONTAL is defined, that kernel loads a single 3-tap row and writes 6 destination Z planes.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_HORIZONTAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x1_3x1_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_4x4_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
#endif // defined(WINOGRAD_FILTER_TRANSFORM_HORIZONTAL)
| 385 | |
| 386 | #if defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL) |
/** This OpenCL kernel performs Winograd filter transform 1x3 when the data layout is NCHW and the output tile is 1x2
 *
 * Thin wrapper: forwards every argument unchanged to winograd_filter_transform_2x2_3x3_nchw. Because
 * WINOGRAD_FILTER_TRANSFORM_VERTICAL is defined, that kernel gathers three taps along Y (one src_stride_y apart)
 * and writes 4 destination Z planes.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time to perform Winograd Filter Transform
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x2_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    winograd_filter_transform_2x2_3x3_nchw(src_ptr,
                                           src_stride_x,
                                           src_step_x,
                                           src_stride_y,
                                           src_step_y,
                                           src_stride_z,
                                           src_step_z,
                                           src_stride_w,
                                           src_step_w,
                                           src_offset_first_element_in_bytes,
                                           dst_ptr,
                                           dst_stride_x,
                                           dst_step_x,
                                           dst_stride_y,
                                           dst_step_y,
                                           dst_stride_z,
                                           dst_step_z,
                                           dst_offset_first_element_in_bytes);
}
| 434 | |
/** Winograd filter transform for a 1x3 filter with a 1x4 output tile, NCHW data layout.
 *
 * This kernel is a thin wrapper: every argument is forwarded unchanged to
 * winograd_filter_transform_4x4_3x3_nchw.
 *
 * @note The size of the source tensor along the Z axis (channels for NCHW, height for NHWC) must be passed at compile time so the input can be split in batches: e.g. -DSRC_DIM_Z=64
 * @note -DWINOGRAD_FILTER_TRANSFORM_VERTICAL has to be passed at compile time; this kernel is only compiled when that macro is defined
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_1x4_1x3_nchw(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Delegate to the 4x4/3x3 NCHW transform. Presumably that kernel switches to its 1x3
    // path when WINOGRAD_FILTER_TRANSFORM_VERTICAL is defined — confirm against its body.
    winograd_filter_transform_4x4_3x3_nchw(
        src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
        src_stride_z, src_step_z, src_stride_w, src_step_w,
        src_offset_first_element_in_bytes,
        dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
        dst_stride_z, dst_step_z,
        dst_offset_first_element_in_bytes);
}
| 482 | #endif // defined(WINOGRAD_FILTER_TRANSFORM_VERTICAL) |
| 483 | |
/** This OpenCL kernel performs Winograd filter transform 3x3 when the data layout is NHWC and the output tile is 4x4
 *
 * Each work-item reads one 3x3 filter slice and writes the 36 values of the transformed 6x6 tile.
 * The slice is selected by get_global_id(0) along X (used as the channel index) and by
 * get_global_id(2) along W (used as the filter index).
 * The 36 values are written to consecutive Z planes of the destination.
 *
 * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_3x3_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // The source address is built directly from the raw kernel arguments; no Tensor4D helper struct is needed
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * src_step_x + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the 3x3 filter: wRC is read at byte offset R * src_stride_z + C * src_stride_y
    const float w00 = *((__global float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    const float w01 = *((__global float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    const float w02 = *((__global float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    const float w10 = *((__global float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    const float w11 = *((__global float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    const float w12 = *((__global float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    const float w20 = *((__global float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    const float w21 = *((__global float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    const float w22 = *((__global float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));

    // Transform the 3x3 tile in a 6x6 tile: outRC = (G * w * G^T)[R][C], with the 6x3 matrix
    // G = [1/4 0 0; -1/6 -1/6 -1/6; -1/6 1/6 -1/6; 1/24 1/12 1/6; 1/24 -1/12 1/6; 0 0 1].
    // Every output is assigned exactly once, so no zero-initialisation is needed.

    // Row 0
    const float out00 = (w00) / 16.f;
    const float out01 = (-w00 - w01 - w02) / 24.f;
    const float out02 = (-w00 + w01 - w02) / 24.f;
    const float out03 = (w00 + 2.f * w01 + 4.f * w02) / 96.f;
    const float out04 = (w00 - 2.f * w01 + 4.f * w02) / 96.f;
    const float out05 = (w02) / 4.f;

    // Row 1
    const float out10 = (-w00 - w10 - w20) / 24.f;
    const float out11 = (w00 + w10 + w20 + w01 + w11 + w21 + w02 + w12 + w22) / 36.f;
    const float out12 = (w00 + w10 + w20 - w01 - w11 - w21 + w02 + w12 + w22) / 36.f;
    const float out13 = (-w00 - w10 - w20 + 2.f * (-w01 - w11 - w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out14 = (-w00 - w10 - w20 + 2.f * (w01 + w11 + w21) + 4.f * (-w02 - w12 - w22)) / 144.f;
    const float out15 = (-w02 - w12 - w22) / 6.f;

    // Row 2
    const float out20 = (-w00 + w10 - w20) / 24.f;
    const float out21 = (w00 - w10 + w20 + w01 - w11 + w21 + w02 - w12 + w22) / 36.f;
    const float out22 = (w00 - w10 + w20 - w01 + w11 - w21 + w02 - w12 + w22) / 36.f;
    const float out23 = (-w00 + w10 - w20 + 2.f * (-w01 + w11 - w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out24 = (-w00 + w10 - w20 + 2.f * (w01 - w11 + w21) + 4.f * (-w02 + w12 - w22)) / 144.f;
    const float out25 = (-w02 + w12 - w22) / 6.f;

    // Row 3
    const float out30 = (w00 + 2.f * w10 + 4.f * w20) / 96.f;
    const float out31 = (-w00 - 2.f * w10 - 4.f * w20 - w01 - 2.f * w11 - 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out32 = (-w00 - 2.f * w10 - 4.f * w20 + w01 + 2.f * w11 + 4.f * w21 - w02 - 2.f * w12 - 4.f * w22) / 144.f;
    const float out33 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (w01 + 2.f * w11 + 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out34 = ((w00 + 2.f * w10 + 4.f * w20) + 2.f * (-w01 - 2.f * w11 - 4.f * w21) + 4.f * (w02 + 2.f * w12 + 4.f * w22)) / 576.f;
    const float out35 = (w02 + 2.f * w12 + 4.f * w22) / 24.f;

    // Row 4
    const float out40 = (w00 - 2.f * w10 + 4.f * w20) / 96.f;
    const float out41 = (-w00 + 2.f * w10 - 4.f * w20 - w01 + 2.f * w11 - 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out42 = (-w00 + 2.f * w10 - 4.f * w20 + w01 - 2.f * w11 + 4.f * w21 - w02 + 2.f * w12 - 4.f * w22) / 144.f;
    const float out43 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (w01 - 2.f * w11 + 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out44 = ((w00 - 2.f * w10 + 4.f * w20) + 2.f * (-w01 + 2.f * w11 - 4.f * w21) + 4.f * (w02 - 2.f * w12 + 4.f * w22)) / 576.f;
    const float out45 = (w02 - 2.f * w12 + 4.f * w22) / 24.f;

    // Row 5
    const float out50 = (w20) / 4.f;
    const float out51 = (-w20 - w21 - w22) / 6.f;
    const float out52 = (-w20 + w21 - w22) / 6.f;
    const float out53 = (w20 + 2.f * w21 + 4.f * w22) / 24.f;
    const float out54 = (w20 - 2.f * w21 + 4.f * w22) / 24.f;
    const float out55 = (w22);

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y;

    // Store the values across the channels: element (R, C) of the 6x6 tile goes to Z plane R * 6 + C
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out00;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out01;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out02;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out03;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out04;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out05;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out10;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out11;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out12;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out13;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out14;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out15;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out20;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out21;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out22;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out23;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out24;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out25;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out30;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out31;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out32;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out33;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out34;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out35;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out40;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out41;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out42;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out43;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out44;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out45;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out50;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out51;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out52;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out53;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out54;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out55;
}
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 633 | /** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NCHW and the output tile is 4x4 |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 634 | * |
Giorgio Arena | dcb5b28 | 2018-04-25 12:07:29 +0100 | [diff] [blame] | 635 | * @note In order to correctly split the input tensor in batches, its dimension across the Z axis (channels for NCHW, height for NHWC) must be passed at compile time using -DSRC_DIM_Z: e.g. -DSRC_DIM_Z=64 |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 636 | * |
| 637 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 638 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 639 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 640 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 641 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 642 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 643 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 644 | * @param[in] src_stride_w Stride of the source tensor in W dimension (in bytes) |
| 645 | * @param[in] src_step_w src_stride_w * number of elements along W processed per workitem(in bytes) |
| 646 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 647 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 648 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 649 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 650 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 651 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 652 | * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) |
| 653 | * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes) |
| 654 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 655 | */ |
| 656 | __kernel void winograd_filter_transform_4x4_5x5_nchw( |
| 657 | TENSOR4D_DECLARATION(src), |
| 658 | TENSOR3D_DECLARATION(dst)) |
| 659 | { |
Giorgio Arena | dcb5b28 | 2018-04-25 12:07:29 +0100 | [diff] [blame] | 660 | Tensor4D src = CONVERT_TO_TENSOR4D_STRUCT(src, SRC_DIM_Z); |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 661 | |
| 662 | const __global uchar *src_addr = tensor4D_offset(&src, 0, 0, 0, 0); |
| 663 | |
| 664 | // Load the values from the input tensor |
| 665 | const char stride_x = 4 * sizeof(float); // Used for accessing the last value in each row |
| 666 | const uchar8 stride_y = (uchar8)(0, 1, 2, 3, 4, 0, 0, 0) * (uchar8)src_stride_y; |
| 667 | |
| 668 | float4 w00 = vload4(0, (__global float *)(src_addr + stride_y.s0)); |
| 669 | float w01 = *((__global float *)(src_addr + stride_y.s0 + stride_x)); |
| 670 | float4 w10 = vload4(0, (__global float *)(src_addr + stride_y.s1)); |
| 671 | float w11 = *((__global float *)(src_addr + stride_y.s1 + stride_x)); |
| 672 | float4 w20 = vload4(0, (__global float *)(src_addr + stride_y.s2)); |
| 673 | float w21 = *((__global float *)(src_addr + stride_y.s2 + stride_x)); |
| 674 | float4 w30 = vload4(0, (__global float *)(src_addr + stride_y.s3)); |
| 675 | float w31 = *((__global float *)(src_addr + stride_y.s3 + stride_x)); |
| 676 | float4 w40 = vload4(0, (__global float *)(src_addr + stride_y.s4)); |
| 677 | float w41 = *((__global float *)(src_addr + stride_y.s4 + stride_x)); |
| 678 | |
| 679 | // Transform the 3x3 tile in a 8x8 tile |
| 680 | float8 out0 = 0.0f; |
| 681 | float8 out1 = 0.0f; |
| 682 | float8 out2 = 0.0f; |
| 683 | float8 out3 = 0.0f; |
| 684 | float8 out4 = 0.0f; |
| 685 | float8 out5 = 0.0f; |
| 686 | float8 out6 = 0.0f; |
| 687 | float8 out7 = 0.0f; |
| 688 | |
| 689 | // Row 0 |
| 690 | out0.s0 = w00.s0; |
| 691 | out0.s1 = -2.f * (w00.s0 + w00.s1 + w00.s2 + w00.s3 + w01) / 9.f; |
| 692 | out0.s2 = -2.f * (w00.s0 - w00.s1 + w00.s2 - w00.s3 + w01) / 9.f; |
| 693 | out0.s3 = (w00.s0 + 2.f * w00.s1 + 4.f * w00.s2 + 8.f * w00.s3 + 16.f * w01) / 90.f; |
| 694 | out0.s4 = (w00.s0 - 2.f * w00.s1 + 4.f * w00.s2 - 8.f * w00.s3 + 16.f * w01) / 90.f; |
| 695 | out0.s5 = (16.f * w00.s0 + 8.f * w00.s1 + 4.f * w00.s2 + 2.f * w00.s3 + w01) / 180.f; |
| 696 | out0.s6 = (16.f * w00.s0 - 8.f * w00.s1 + 4.f * w00.s2 - 2.f * w00.s3 + w01) / 180.f; |
| 697 | out0.s7 = w01; |
| 698 | |
| 699 | // Row 1 |
| 700 | out1.s0 = -2.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) / 9.f; |
| 701 | out1.s1 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + |
| 702 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f; |
| 703 | out1.s2 = 4.f * ((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - |
| 704 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 81.f; |
| 705 | out1.s3 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 8.f * |
| 706 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f; |
| 707 | out1.s4 = -((w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 2.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 8.f * |
| 708 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + 16.f * (w01 + w11 + w21 + w31 + w41)) / 405.f; |
| 709 | out1.s5 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) + 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) + 2.f * |
| 710 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f; |
| 711 | out1.s6 = -(16.f * (w00.s0 + w10.s0 + w20.s0 + w30.s0 + w40.s0) - 8.f * (w00.s1 + w10.s1 + w20.s1 + w30.s1 + w40.s1) + 4.f * (w00.s2 + w10.s2 + w20.s2 + w30.s2 + w40.s2) - 2.f * |
| 712 | (w00.s3 + w10.s3 + w20.s3 + w30.s3 + w40.s3) + (w01 + w11 + w21 + w31 + w41)) / 810.f; |
| 713 | out1.s7 = -2.f * (w01 + w11 + w21 + w31 + w41) / 9.f; |
| 714 | |
| 715 | // Row 2 |
| 716 | out2.s0 = -2.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) / 9.f; |
| 717 | out2.s1 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + |
| 718 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f; |
| 719 | out2.s2 = 4.f * ((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - |
| 720 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 81.f; |
| 721 | out2.s3 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 8.f * |
| 722 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f; |
| 723 | out2.s4 = -((w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 2.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 8.f * |
| 724 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + 16.f * (w01 - w11 + w21 - w31 + w41)) / 405.f; |
| 725 | out2.s5 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) + 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) + 2.f * |
| 726 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f; |
| 727 | out2.s6 = -(16.f * (w00.s0 - w10.s0 + w20.s0 - w30.s0 + w40.s0) - 8.f * (w00.s1 - w10.s1 + w20.s1 - w30.s1 + w40.s1) + 4.f * (w00.s2 - w10.s2 + w20.s2 - w30.s2 + w40.s2) - 2.f * |
| 728 | (w00.s3 - w10.s3 + w20.s3 - w30.s3 + w40.s3) + (w01 - w11 + w21 - w31 + w41)) / 810.f; |
| 729 | out2.s7 = -2.f * (w01 - w11 + w21 - w31 + w41) / 9.f; |
| 730 | |
| 731 | // Row 3 |
| 732 | out3.s0 = (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) / 90.f; |
| 733 | out3.s1 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + |
| 734 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + |
| 735 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f; |
| 736 | out3.s2 = -((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + |
| 737 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + |
| 738 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 405.f; |
| 739 | out3.s3 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 740 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f * |
| 741 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f; |
| 742 | out3.s4 = ((w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 743 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + 16.f * |
| 744 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 8100.f; |
| 745 | out3.s5 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 746 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + |
| 747 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f; |
| 748 | out3.s6 = (16.f * (w00.s0 + 2.f * w10.s0 + 4.f * w20.s0 + 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 + 2.f * w10.s1 + 4.f * w20.s1 + 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 749 | (w00.s2 + 2.f * w10.s2 + 4.f * w20.s2 + 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 + 2.f * w10.s3 + 4.f * w20.s3 + 8.f * w30.s3 + 16.f * w40.s3) + |
| 750 | (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41)) / 16200.f; |
| 751 | out3.s7 = (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) / 90.f; |
| 752 | |
| 753 | // Row 4 |
| 754 | out4.s0 = (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) / 90.f; |
| 755 | out4.s1 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + |
| 756 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + |
| 757 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f; |
| 758 | out4.s2 = -((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + |
| 759 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + |
| 760 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 405.f; |
| 761 | out4.s3 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 762 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f * |
| 763 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f; |
| 764 | out4.s4 = ((w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 2.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 765 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 8.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + 16.f * |
| 766 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 8100.f; |
| 767 | out4.s5 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) + 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 768 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) + 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + |
| 769 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f; |
| 770 | out4.s6 = (16.f * (w00.s0 - 2.f * w10.s0 + 4.f * w20.s0 - 8.f * w30.s0 + 16.f * w40.s0) - 8.f * (w00.s1 - 2.f * w10.s1 + 4.f * w20.s1 - 8.f * w30.s1 + 16.f * w40.s1) + 4.f * |
| 771 | (w00.s2 - 2.f * w10.s2 + 4.f * w20.s2 - 8.f * w30.s2 + 16.f * w40.s2) - 2.f * (w00.s3 - 2.f * w10.s3 + 4.f * w20.s3 - 8.f * w30.s3 + 16.f * w40.s3) + |
| 772 | (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41)) / 16200.f; |
| 773 | out4.s7 = (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) / 90.f; |
| 774 | |
| 775 | // Row 5 |
| 776 | out5.s0 = (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) / 180.f; |
| 777 | out5.s1 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + |
| 778 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + |
| 779 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f; |
| 780 | out5.s2 = -((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + |
| 781 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + |
| 782 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 810.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 783 | out5.s3 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 784 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f * |
| 785 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 786 | out5.s4 = ((16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 787 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + 16.f * |
| 788 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 16200.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 789 | out5.s5 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 790 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + |
| 791 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 792 | out5.s6 = (16.f * (16.f * w00.s0 + 8.f * w10.s0 + 4.f * w20.s0 + 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 + 8.f * w10.s1 + 4.f * w20.s1 + 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 793 | (16.f * w00.s2 + 8.f * w10.s2 + 4.f * w20.s2 + 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 + 8.f * w10.s3 + 4.f * w20.s3 + 2.f * w30.s3 + w40.s3) + |
| 794 | (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41)) / 32400.f; |
| 795 | out5.s7 = (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) / 180.f; |
| 796 | |
| 797 | // Row 6 |
| 798 | out6.s0 = (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) / 180.f; |
| 799 | out6.s1 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + |
| 800 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + |
| 801 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f; |
| 802 | out6.s2 = -((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + |
| 803 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + |
| 804 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 810.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 805 | out6.s3 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 806 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f * |
| 807 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 808 | out6.s4 = ((16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 2.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 809 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 8.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + 16.f * |
| 810 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 16200.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 811 | out6.s5 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) + 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 812 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) + 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + |
| 813 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f; |
Gian Marco Iodice | 2213d4b | 2018-04-27 10:39:06 +0100 | [diff] [blame] | 814 | out6.s6 = (16.f * (16.f * w00.s0 - 8.f * w10.s0 + 4.f * w20.s0 - 2.f * w30.s0 + w40.s0) - 8.f * (16.f * w00.s1 - 8.f * w10.s1 + 4.f * w20.s1 - 2.f * w30.s1 + w40.s1) + 4.f * |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 815 | (16.f * w00.s2 - 8.f * w10.s2 + 4.f * w20.s2 - 2.f * w30.s2 + w40.s2) - 2.f * (16.f * w00.s3 - 8.f * w10.s3 + 4.f * w20.s3 - 2.f * w30.s3 + w40.s3) + |
| 816 | (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41)) / 32400.f; |
| 817 | out6.s7 = (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) / 180.f; |
| 818 | |
| 819 | // Row 7 |
| 820 | out7.s0 = w40.s0; |
| 821 | out7.s1 = -2.f * (w40.s0 + w40.s1 + w40.s2 + w40.s3 + w41) / 9.f; |
| 822 | out7.s2 = -2.f * (w40.s0 - w40.s1 + w40.s2 - w40.s3 + w41) / 9.f; |
| 823 | out7.s3 = (w40.s0 + 2.f * w40.s1 + 4.f * w40.s2 + 8.f * w40.s3 + 16.f * w41) / 90.f; |
| 824 | out7.s4 = (w40.s0 - 2.f * w40.s1 + 4.f * w40.s2 - 8.f * w40.s3 + 16.f * w41) / 90.f; |
| 825 | out7.s5 = (16.f * w40.s0 + 8.f * w40.s1 + 4.f * w40.s2 + 2.f * w40.s3 + w41) / 180.f; |
| 826 | out7.s6 = (16.f * w40.s0 - 8.f * w40.s1 + 4.f * w40.s2 - 2.f * w40.s3 + w41) / 180.f; |
| 827 | out7.s7 = w41; |
| 828 | |
| 829 | int z = get_global_id(2); |
Giorgio Arena | dcb5b28 | 2018-04-25 12:07:29 +0100 | [diff] [blame] | 830 | int x0 = z / SRC_DIM_Z; // idx filter |
| 831 | int y0 = z % SRC_DIM_Z; // idx channel |
Giorgio Arena | 9373c8b | 2018-04-11 19:07:17 +0100 | [diff] [blame] | 832 | |
| 833 | // Get output address |
| 834 | __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * dst_stride_x + y0 * dst_stride_y; |
| 835 | |
| 836 | // Store the 64 values across the 64 channels |
| 837 | *(__global float *)(dst_addr + 0 * dst_stride_z) = out0.s0; |
| 838 | *(__global float *)(dst_addr + 1 * dst_stride_z) = out0.s1; |
| 839 | *(__global float *)(dst_addr + 2 * dst_stride_z) = out0.s2; |
| 840 | *(__global float *)(dst_addr + 3 * dst_stride_z) = out0.s3; |
| 841 | *(__global float *)(dst_addr + 4 * dst_stride_z) = out0.s4; |
| 842 | *(__global float *)(dst_addr + 5 * dst_stride_z) = out0.s5; |
| 843 | *(__global float *)(dst_addr + 6 * dst_stride_z) = out0.s6; |
| 844 | *(__global float *)(dst_addr + 7 * dst_stride_z) = out0.s7; |
| 845 | *(__global float *)(dst_addr + 8 * dst_stride_z) = out1.s0; |
| 846 | *(__global float *)(dst_addr + 9 * dst_stride_z) = out1.s1; |
| 847 | *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2; |
| 848 | *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3; |
| 849 | *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4; |
| 850 | *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5; |
| 851 | *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6; |
| 852 | *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7; |
| 853 | *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0; |
| 854 | *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1; |
| 855 | *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2; |
| 856 | *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3; |
| 857 | *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4; |
| 858 | *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5; |
| 859 | *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6; |
| 860 | *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7; |
| 861 | *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0; |
| 862 | *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1; |
| 863 | *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2; |
| 864 | *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3; |
| 865 | *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4; |
| 866 | *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5; |
| 867 | *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6; |
| 868 | *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7; |
| 869 | *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0; |
| 870 | *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1; |
| 871 | *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2; |
| 872 | *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3; |
| 873 | *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4; |
| 874 | *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5; |
| 875 | *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6; |
| 876 | *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7; |
| 877 | *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0; |
| 878 | *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1; |
| 879 | *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2; |
| 880 | *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3; |
| 881 | *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4; |
| 882 | *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5; |
| 883 | *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6; |
| 884 | *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7; |
| 885 | *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0; |
| 886 | *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1; |
| 887 | *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2; |
| 888 | *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3; |
| 889 | *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4; |
| 890 | *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5; |
| 891 | *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6; |
| 892 | *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7; |
| 893 | *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0; |
| 894 | *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1; |
| 895 | *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2; |
| 896 | *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3; |
| 897 | *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4; |
| 898 | *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5; |
| 899 | *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6; |
| 900 | *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7; |
| 901 | } |
Giorgio Arena | 80d65d8 | 2018-06-08 16:30:00 +0100 | [diff] [blame] | 902 | |
/** This OpenCL kernel performs Winograd filter transform 5x5 when the data layout is NHWC and the output tile is 4x4
 *
 * Each work-item reads one 5x5 plane of the source tensor (channel selected by get_global_id(0), filter selected by
 * get_global_id(2) through the W stride) and writes the 64 values of the corresponding 8x8 transformed tile.
 *
 * The tile computed is out = G * w * G^T, where w[r][c] is the filter value in row r (src Z) and column c (src Y),
 * and G is the 8x5 matrix whose coefficients appear in the "Row 0" expressions below:
 *
 *       [  1       0       0       0       0     ]
 *       [ -2/9    -2/9    -2/9    -2/9    -2/9   ]
 *       [ -2/9     2/9    -2/9     2/9    -2/9   ]
 *   G = [  1/90    2/90    4/90    8/90   16/90  ]
 *       [  1/90   -2/90    4/90   -8/90   16/90  ]
 *       [ 16/180   8/180   4/180   2/180   1/180 ]
 *       [ 16/180  -8/180   4/180  -2/180   1/180 ]
 *       [  0       0       0       0       1     ]
 *
 * Output layout: element (i, j) of the 8x8 tile (out<i>.s<j>) is written at Z index 8 * i + j of dst,
 * at X = filter index (get_global_id(2)) and Y = channel index (get_global_id(0)).
 *
 * @note -DSRC_DIM_Z must be passed at compile time (e.g. -DSRC_DIM_Z=64) because this kernel is only built when SRC_DIM_Z is defined.
 *       Its value is not used by this kernel: the filter is selected directly with get_global_id(2) * src_step_w.
 *
 * @param[in]  src_ptr                           Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_stride_w                      Stride of the source tensor in W dimension (in bytes)
 * @param[in]  src_step_w                        src_stride_w * number of elements along W processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source tensor
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_filter_transform_4x4_5x5_nhwc(
    TENSOR4D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Base address of the 5x5 plane handled by this work-item: channel from global id 0, filter (W) from global id 2
    const __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + get_global_id(0) * sizeof(float) + get_global_id(1) * src_step_y + get_global_id(2) * src_step_w;

    // Load the values from the input tensor: wRC is row R (src Z) and column C (src Y) of the 5x5 filter
    float w00 = *((__global const float *)(src_addr + 0 * src_stride_z + 0 * src_stride_y));
    float w01 = *((__global const float *)(src_addr + 0 * src_stride_z + 1 * src_stride_y));
    float w02 = *((__global const float *)(src_addr + 0 * src_stride_z + 2 * src_stride_y));
    float w03 = *((__global const float *)(src_addr + 0 * src_stride_z + 3 * src_stride_y));
    float w04 = *((__global const float *)(src_addr + 0 * src_stride_z + 4 * src_stride_y));
    float w10 = *((__global const float *)(src_addr + 1 * src_stride_z + 0 * src_stride_y));
    float w11 = *((__global const float *)(src_addr + 1 * src_stride_z + 1 * src_stride_y));
    float w12 = *((__global const float *)(src_addr + 1 * src_stride_z + 2 * src_stride_y));
    float w13 = *((__global const float *)(src_addr + 1 * src_stride_z + 3 * src_stride_y));
    float w14 = *((__global const float *)(src_addr + 1 * src_stride_z + 4 * src_stride_y));
    float w20 = *((__global const float *)(src_addr + 2 * src_stride_z + 0 * src_stride_y));
    float w21 = *((__global const float *)(src_addr + 2 * src_stride_z + 1 * src_stride_y));
    float w22 = *((__global const float *)(src_addr + 2 * src_stride_z + 2 * src_stride_y));
    float w23 = *((__global const float *)(src_addr + 2 * src_stride_z + 3 * src_stride_y));
    float w24 = *((__global const float *)(src_addr + 2 * src_stride_z + 4 * src_stride_y));
    float w30 = *((__global const float *)(src_addr + 3 * src_stride_z + 0 * src_stride_y));
    float w31 = *((__global const float *)(src_addr + 3 * src_stride_z + 1 * src_stride_y));
    float w32 = *((__global const float *)(src_addr + 3 * src_stride_z + 2 * src_stride_y));
    float w33 = *((__global const float *)(src_addr + 3 * src_stride_z + 3 * src_stride_y));
    float w34 = *((__global const float *)(src_addr + 3 * src_stride_z + 4 * src_stride_y));
    float w40 = *((__global const float *)(src_addr + 4 * src_stride_z + 0 * src_stride_y));
    float w41 = *((__global const float *)(src_addr + 4 * src_stride_z + 1 * src_stride_y));
    float w42 = *((__global const float *)(src_addr + 4 * src_stride_z + 2 * src_stride_y));
    float w43 = *((__global const float *)(src_addr + 4 * src_stride_z + 3 * src_stride_y));
    float w44 = *((__global const float *)(src_addr + 4 * src_stride_z + 4 * src_stride_y));

    // Transform the 5x5 tile in a 8x8 tile: out<i>.s<j> = sum_{r,c} G[i][r] * w[r][c] * G[j][c].
    // The two scale factors of G[i] and G[j] are folded into a single final multiplication/division per element.
    float8 out0 = 0.0f;
    float8 out1 = 0.0f;
    float8 out2 = 0.0f;
    float8 out3 = 0.0f;
    float8 out4 = 0.0f;
    float8 out5 = 0.0f;
    float8 out6 = 0.0f;
    float8 out7 = 0.0f;

    // Row 0
    out0.s0 = w00;
    out0.s1 = -2.f * (w00 + w01 + w02 + w03 + w04) / 9.f;
    out0.s2 = -2.f * (w00 - w01 + w02 - w03 + w04) / 9.f;
    out0.s3 = (w00 + 2.f * w01 + 4.f * w02 + 8.f * w03 + 16.f * w04) / 90.f;
    out0.s4 = (w00 - 2.f * w01 + 4.f * w02 - 8.f * w03 + 16.f * w04) / 90.f;
    out0.s5 = (16.f * w00 + 8.f * w01 + 4.f * w02 + 2.f * w03 + w04) / 180.f;
    out0.s6 = (16.f * w00 - 8.f * w01 + 4.f * w02 - 2.f * w03 + w04) / 180.f;
    out0.s7 = w04;

    // Row 1
    out1.s0 = -2.f * (w00 + w10 + w20 + w30 + w40) / 9.f;
    out1.s1 = 4.f * ((w00 + w10 + w20 + w30 + w40) + (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) + (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
    out1.s2 = 4.f * ((w00 + w10 + w20 + w30 + w40) - (w01 + w11 + w21 + w31 + w41) + (w02 + w12 + w22 + w32 + w42) - (w03 + w13 + w23 + w33 + w43) + (w04 + w14 + w24 + w34 + w44)) / 81.f;
    out1.s3 = -((w00 + w10 + w20 + w30 + w40) + 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
                (w04 + w14 + w24 + w34 + w44)) / 405.f;
    out1.s4 = -((w00 + w10 + w20 + w30 + w40) - 2.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 8.f * (w03 + w13 + w23 + w33 + w43) + 16.f *
                (w04 + w14 + w24 + w34 + w44)) / 405.f;
    out1.s5 = -(16.f * (w00 + w10 + w20 + w30 + w40) + 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) + 2.f * (w03 + w13 + w23 + w33 + w43) +
                (w04 + w14 + w24 + w34 + w44)) / 810.f;
    out1.s6 = -(16.f * (w00 + w10 + w20 + w30 + w40) - 8.f * (w01 + w11 + w21 + w31 + w41) + 4.f * (w02 + w12 + w22 + w32 + w42) - 2.f * (w03 + w13 + w23 + w33 + w43) +
                (w04 + w14 + w24 + w34 + w44)) / 810.f;
    out1.s7 = -2.f * (w04 + w14 + w24 + w34 + w44) / 9.f;

    // Row 2
    out2.s0 = -2.f * (w00 - w10 + w20 - w30 + w40) / 9.f;
    out2.s1 = 4.f * ((w00 - w10 + w20 - w30 + w40) + (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) + (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
    out2.s2 = 4.f * ((w00 - w10 + w20 - w30 + w40) - (w01 - w11 + w21 - w31 + w41) + (w02 - w12 + w22 - w32 + w42) - (w03 - w13 + w23 - w33 + w43) + (w04 - w14 + w24 - w34 + w44)) / 81.f;
    out2.s3 = -((w00 - w10 + w20 - w30 + w40) + 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
                (w04 - w14 + w24 - w34 + w44)) / 405.f;
    out2.s4 = -((w00 - w10 + w20 - w30 + w40) - 2.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 8.f * (w03 - w13 + w23 - w33 + w43) + 16.f *
                (w04 - w14 + w24 - w34 + w44)) / 405.f;
    out2.s5 = -(16.f * (w00 - w10 + w20 - w30 + w40) + 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) + 2.f * (w03 - w13 + w23 - w33 + w43) +
                (w04 - w14 + w24 - w34 + w44)) / 810.f;
    out2.s6 = -(16.f * (w00 - w10 + w20 - w30 + w40) - 8.f * (w01 - w11 + w21 - w31 + w41) + 4.f * (w02 - w12 + w22 - w32 + w42) - 2.f * (w03 - w13 + w23 - w33 + w43) +
                (w04 - w14 + w24 - w34 + w44)) / 810.f;
    out2.s7 = -2.f * (w04 - w14 + w24 - w34 + w44) / 9.f;

    // Row 3
    out3.s0 = (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) / 90.f;
    out3.s1 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) +
                (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
    out3.s2 = -((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) -
                (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 405.f;
    out3.s3 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 8.f
               * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
    out3.s4 = ((w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 2.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f * (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 8.f
               * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + 16.f * (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 8100.f;
    out3.s5 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) + 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
               (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) + 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
    out3.s6 = (16.f * (w00 + 2.f * w10 + 4.f * w20 + 8.f * w30 + 16.f * w40) - 8.f * (w01 + 2.f * w11 + 4.f * w21 + 8.f * w31 + 16.f * w41) + 4.f *
               (w02 + 2.f * w12 + 4.f * w22 + 8.f * w32 + 16.f * w42) - 2.f * (w03 + 2.f * w13 + 4.f * w23 + 8.f * w33 + 16.f * w43) + (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44)) / 16200.f;
    out3.s7 = (w04 + 2.f * w14 + 4.f * w24 + 8.f * w34 + 16.f * w44) / 90.f;

    // Row 4
    out4.s0 = (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) / 90.f;
    out4.s1 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) +
                (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
    out4.s2 = -((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) -
                (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 405.f;
    out4.s3 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 8.f
               * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
    out4.s4 = ((w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 2.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f * (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 8.f
               * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + 16.f * (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 8100.f;
    out4.s5 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) + 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
               (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) + 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
    out4.s6 = (16.f * (w00 - 2.f * w10 + 4.f * w20 - 8.f * w30 + 16.f * w40) - 8.f * (w01 - 2.f * w11 + 4.f * w21 - 8.f * w31 + 16.f * w41) + 4.f *
               (w02 - 2.f * w12 + 4.f * w22 - 8.f * w32 + 16.f * w42) - 2.f * (w03 - 2.f * w13 + 4.f * w23 - 8.f * w33 + 16.f * w43) + (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44)) / 16200.f;
    out4.s7 = (w04 - 2.f * w14 + 4.f * w24 - 8.f * w34 + 16.f * w44) / 90.f;

    // Row 5
    out5.s0 = (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) / 180.f;
    out5.s1 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) +
                (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
    out5.s2 = -((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) -
                (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 810.f;
    out5.s3 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 8.f
               * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
    out5.s4 = ((16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 2.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f * (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 8.f
               * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + 16.f * (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 16200.f;
    out5.s5 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) + 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
               (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) + 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
    out5.s6 = (16.f * (16.f * w00 + 8.f * w10 + 4.f * w20 + 2.f * w30 + w40) - 8.f * (16.f * w01 + 8.f * w11 + 4.f * w21 + 2.f * w31 + w41) + 4.f *
               (16.f * w02 + 8.f * w12 + 4.f * w22 + 2.f * w32 + w42) - 2.f * (16.f * w03 + 8.f * w13 + 4.f * w23 + 2.f * w33 + w43) + (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44)) / 32400.f;
    out5.s7 = (16.f * w04 + 8.f * w14 + 4.f * w24 + 2.f * w34 + w44) / 180.f;

    // Row 6
    out6.s0 = (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) / 180.f;
    out6.s1 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) +
                (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
    out6.s2 = -((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) -
                (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 810.f;
    out6.s3 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 8.f
               * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
    out6.s4 = ((16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 2.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f * (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 8.f
               * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + 16.f * (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 16200.f;
    out6.s5 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) + 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
               (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) + 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
    out6.s6 = (16.f * (16.f * w00 - 8.f * w10 + 4.f * w20 - 2.f * w30 + w40) - 8.f * (16.f * w01 - 8.f * w11 + 4.f * w21 - 2.f * w31 + w41) + 4.f *
               (16.f * w02 - 8.f * w12 + 4.f * w22 - 2.f * w32 + w42) - 2.f * (16.f * w03 - 8.f * w13 + 4.f * w23 - 2.f * w33 + w43) + (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44)) / 32400.f;
    out6.s7 = (16.f * w04 - 8.f * w14 + 4.f * w24 - 2.f * w34 + w44) / 180.f;

    // Row 7
    out7.s0 = w40;
    out7.s1 = -2.f * (w40 + w41 + w42 + w43 + w44) / 9.f;
    out7.s2 = -2.f * (w40 - w41 + w42 - w43 + w44) / 9.f;
    out7.s3 = (w40 + 2.f * w41 + 4.f * w42 + 8.f * w43 + 16.f * w44) / 90.f;
    out7.s4 = (w40 - 2.f * w41 + 4.f * w42 - 8.f * w43 + 16.f * w44) / 90.f;
    out7.s5 = (16.f * w40 + 8.f * w41 + 4.f * w42 + 2.f * w43 + w44) / 180.f;
    out7.s6 = (16.f * w40 - 8.f * w41 + 4.f * w42 - 2.f * w43 + w44) / 180.f;
    out7.s7 = w44;

    const int x0 = get_global_id(2); // idx filter
    const int y0 = get_global_id(0); // idx channel

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x0 * sizeof(float) + y0 * dst_stride_y;

    // Store the 64 values of the 8x8 tile along the Z dimension of dst (row-major: out<i>.s<j> goes to Z index 8 * i + j)
    *(__global float *)(dst_addr + 0 * dst_stride_z)  = out0.s0;
    *(__global float *)(dst_addr + 1 * dst_stride_z)  = out0.s1;
    *(__global float *)(dst_addr + 2 * dst_stride_z)  = out0.s2;
    *(__global float *)(dst_addr + 3 * dst_stride_z)  = out0.s3;
    *(__global float *)(dst_addr + 4 * dst_stride_z)  = out0.s4;
    *(__global float *)(dst_addr + 5 * dst_stride_z)  = out0.s5;
    *(__global float *)(dst_addr + 6 * dst_stride_z)  = out0.s6;
    *(__global float *)(dst_addr + 7 * dst_stride_z)  = out0.s7;
    *(__global float *)(dst_addr + 8 * dst_stride_z)  = out1.s0;
    *(__global float *)(dst_addr + 9 * dst_stride_z)  = out1.s1;
    *(__global float *)(dst_addr + 10 * dst_stride_z) = out1.s2;
    *(__global float *)(dst_addr + 11 * dst_stride_z) = out1.s3;
    *(__global float *)(dst_addr + 12 * dst_stride_z) = out1.s4;
    *(__global float *)(dst_addr + 13 * dst_stride_z) = out1.s5;
    *(__global float *)(dst_addr + 14 * dst_stride_z) = out1.s6;
    *(__global float *)(dst_addr + 15 * dst_stride_z) = out1.s7;
    *(__global float *)(dst_addr + 16 * dst_stride_z) = out2.s0;
    *(__global float *)(dst_addr + 17 * dst_stride_z) = out2.s1;
    *(__global float *)(dst_addr + 18 * dst_stride_z) = out2.s2;
    *(__global float *)(dst_addr + 19 * dst_stride_z) = out2.s3;
    *(__global float *)(dst_addr + 20 * dst_stride_z) = out2.s4;
    *(__global float *)(dst_addr + 21 * dst_stride_z) = out2.s5;
    *(__global float *)(dst_addr + 22 * dst_stride_z) = out2.s6;
    *(__global float *)(dst_addr + 23 * dst_stride_z) = out2.s7;
    *(__global float *)(dst_addr + 24 * dst_stride_z) = out3.s0;
    *(__global float *)(dst_addr + 25 * dst_stride_z) = out3.s1;
    *(__global float *)(dst_addr + 26 * dst_stride_z) = out3.s2;
    *(__global float *)(dst_addr + 27 * dst_stride_z) = out3.s3;
    *(__global float *)(dst_addr + 28 * dst_stride_z) = out3.s4;
    *(__global float *)(dst_addr + 29 * dst_stride_z) = out3.s5;
    *(__global float *)(dst_addr + 30 * dst_stride_z) = out3.s6;
    *(__global float *)(dst_addr + 31 * dst_stride_z) = out3.s7;
    *(__global float *)(dst_addr + 32 * dst_stride_z) = out4.s0;
    *(__global float *)(dst_addr + 33 * dst_stride_z) = out4.s1;
    *(__global float *)(dst_addr + 34 * dst_stride_z) = out4.s2;
    *(__global float *)(dst_addr + 35 * dst_stride_z) = out4.s3;
    *(__global float *)(dst_addr + 36 * dst_stride_z) = out4.s4;
    *(__global float *)(dst_addr + 37 * dst_stride_z) = out4.s5;
    *(__global float *)(dst_addr + 38 * dst_stride_z) = out4.s6;
    *(__global float *)(dst_addr + 39 * dst_stride_z) = out4.s7;
    *(__global float *)(dst_addr + 40 * dst_stride_z) = out5.s0;
    *(__global float *)(dst_addr + 41 * dst_stride_z) = out5.s1;
    *(__global float *)(dst_addr + 42 * dst_stride_z) = out5.s2;
    *(__global float *)(dst_addr + 43 * dst_stride_z) = out5.s3;
    *(__global float *)(dst_addr + 44 * dst_stride_z) = out5.s4;
    *(__global float *)(dst_addr + 45 * dst_stride_z) = out5.s5;
    *(__global float *)(dst_addr + 46 * dst_stride_z) = out5.s6;
    *(__global float *)(dst_addr + 47 * dst_stride_z) = out5.s7;
    *(__global float *)(dst_addr + 48 * dst_stride_z) = out6.s0;
    *(__global float *)(dst_addr + 49 * dst_stride_z) = out6.s1;
    *(__global float *)(dst_addr + 50 * dst_stride_z) = out6.s2;
    *(__global float *)(dst_addr + 51 * dst_stride_z) = out6.s3;
    *(__global float *)(dst_addr + 52 * dst_stride_z) = out6.s4;
    *(__global float *)(dst_addr + 53 * dst_stride_z) = out6.s5;
    *(__global float *)(dst_addr + 54 * dst_stride_z) = out6.s6;
    *(__global float *)(dst_addr + 55 * dst_stride_z) = out6.s7;
    *(__global float *)(dst_addr + 56 * dst_stride_z) = out7.s0;
    *(__global float *)(dst_addr + 57 * dst_stride_z) = out7.s1;
    *(__global float *)(dst_addr + 58 * dst_stride_z) = out7.s2;
    *(__global float *)(dst_addr + 59 * dst_stride_z) = out7.s3;
    *(__global float *)(dst_addr + 60 * dst_stride_z) = out7.s4;
    *(__global float *)(dst_addr + 61 * dst_stride_z) = out7.s5;
    *(__global float *)(dst_addr + 62 * dst_stride_z) = out7.s6;
    *(__global float *)(dst_addr + 63 * dst_stride_z) = out7.s7;
}
Giorgio Arena | dcb5b28 | 2018-04-25 12:07:29 +0100 | [diff] [blame] | 1155 | #endif // defined(SRC_DIM_Z) |
Gian Marco Iodice | d2fab73 | 2018-03-02 11:18:12 +0000 | [diff] [blame] | 1156 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1157 | #if defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) |
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3 and the output tile is 2x2/2x1 or 1x2
 *
 * Each work-item (x, y, z) transforms one input patch of channel z: 4 rows of 4 elements for 3x3,
 * one row of 4 elements for 3x1 and one column of 4 elements for 1x3.
 * The transformed values are written to dst at X = z, Y = x + y * NUM_TILES_X and Z = 0..15 (Z = 0..3 in the 3x1/1x3 cases).
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_2x2_3x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2);

    // Compute input address: start of the patch of tile (x, y) in channel z, moved back by the padding
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;

    src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);

#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 3x1: one row of 4 consecutive elements
    float4 in_row0 = vload4(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // 1x3: one column of 4 elements, gathered into a vector so that the same arithmetic as the 3x1 case applies
    float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                              *((__global float *)(src_addr + 1 * src_stride_y)),
                              *((__global float *)(src_addr + 2 * src_stride_y)),
                              *((__global float *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // 3x3: 4 rows of 4 elements
    float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Both passes use the coefficient rows [1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0] and [0, 1, 0, -1]
    // (the F(2, 3) input-transform matrix): first across the input rows (3x3 only), then across the
    // elements of each combined row.
    float4 tmp0 = in_row0;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    tmp0 -= in_row2;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    float out00 = tmp0.s0 - tmp0.s2;
    float out01 = tmp0.s1 + tmp0.s2;
    float out02 = tmp0.s2 - tmp0.s1;
    float out03 = tmp0.s1 - tmp0.s3;

    // Destination: X = channel, Y = tile index, Z = index of the transformed element
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;

    *((__global float *)(dst_addr + 0 * dst_stride_z)) = out00;
    *((__global float *)(dst_addr + 1 * dst_stride_z)) = out01;
    *((__global float *)(dst_addr + 2 * dst_stride_z)) = out02;
    *((__global float *)(dst_addr + 3 * dst_stride_z)) = out03;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining three combined rows (3x3 only)
    float4 tmp1 = in_row1 + in_row2;
    float4 tmp2 = in_row2 - in_row1;
    float4 tmp3 = in_row1 - in_row3;

    float out10 = tmp1.s0 - tmp1.s2;
    float out11 = tmp1.s1 + tmp1.s2;
    float out12 = tmp1.s2 - tmp1.s1;
    float out13 = tmp1.s1 - tmp1.s3;

    float out20 = tmp2.s0 - tmp2.s2;
    float out21 = tmp2.s1 + tmp2.s2;
    float out22 = tmp2.s2 - tmp2.s1;
    float out23 = tmp2.s1 - tmp2.s3;

    float out30 = tmp3.s0 - tmp3.s2;
    float out31 = tmp3.s1 + tmp3.s2;
    float out32 = tmp3.s2 - tmp3.s1;
    float out33 = tmp3.s1 - tmp3.s3;

    *((__global float *)(dst_addr + 4 * dst_stride_z))  = out10;
    *((__global float *)(dst_addr + 5 * dst_stride_z))  = out11;
    *((__global float *)(dst_addr + 6 * dst_stride_z))  = out12;
    *((__global float *)(dst_addr + 7 * dst_stride_z))  = out13;
    *((__global float *)(dst_addr + 8 * dst_stride_z))  = out20;
    *((__global float *)(dst_addr + 9 * dst_stride_z))  = out21;
    *((__global float *)(dst_addr + 10 * dst_stride_z)) = out22;
    *((__global float *)(dst_addr + 11 * dst_stride_z)) = out23;
    *((__global float *)(dst_addr + 12 * dst_stride_z)) = out30;
    *((__global float *)(dst_addr + 13 * dst_stride_z)) = out31;
    *((__global float *)(dst_addr + 14 * dst_stride_z)) = out32;
    *((__global float *)(dst_addr + 15 * dst_stride_z)) = out33;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
| 1265 | |
/** This OpenCL kernel computes the input transform when the kernel size is 3x3/3x1 or 1x3, the output tile is 2x2/2x1 or 1x2 and the number of channels is multiple of 2
 *
 * Each work-item (x, y, g) handles the two consecutive channels z = 2 * g and z + 1. For each channel it
 * transforms one input patch: 4 rows of 4 elements for 3x3, one row of 4 elements for 3x1 and one column
 * of 4 elements for 1x3. The two results for the same transformed element are stored as an adjacent pair
 * at dst X = z, z + 1, Y = x + y * NUM_TILES_X and Z = 0..15 (Z = 0..3 in the 3x1/1x3 cases).
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_2x2_3x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0);
    int y = get_global_id(1);
    int z = get_global_id(2) * 2; // Two channels per work-item

    // Compute input address: start of the patch of tile (x, y) in channel z, moved back by the padding
    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z;

    src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y);

    // Patch of channel z
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row0 = vload4(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    // Column gathered into a vector so that the same arithmetic as the 3x1 case applies
    float4 in_row0 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                              *((__global float *)(src_addr + 1 * src_stride_y)),
                              *((__global float *)(src_addr + 2 * src_stride_y)),
                              *((__global float *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    float4 in_row0 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float4 in_row1 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float4 in_row2 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float4 in_row3 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Patch of channel z + 1
    src_addr += src_stride_z;
#if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row4 = vload4(0, (__global float *)(src_addr));
#elif defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL)
    float4 in_row4 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)),
                              *((__global float *)(src_addr + 1 * src_stride_y)),
                              *((__global float *)(src_addr + 2 * src_stride_y)),
                              *((__global float *)(src_addr + 3 * src_stride_y)));
#else  // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    float4 in_row4 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y));
    float4 in_row5 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y));
    float4 in_row6 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y));
    float4 in_row7 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    // Both passes use the coefficient rows [1, 0, -1, 0], [0, 1, 1, 0], [0, -1, 1, 0] and [0, 1, 0, -1]
    // (the F(2, 3) input-transform matrix): first across the input rows (3x3 only), then across the
    // elements of each combined row. The .s0/.s1 lanes of each float2 hold channel z/z + 1.
    float4 tmp0 = in_row0;
    float4 tmp4 = in_row4;

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    tmp0 -= in_row2;
    tmp4 -= in_row6;
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)

    float2 out00 = (float2)(tmp0.s0 - tmp0.s2, tmp4.s0 - tmp4.s2);
    float2 out01 = (float2)(tmp0.s1 + tmp0.s2, tmp4.s1 + tmp4.s2);
    float2 out02 = (float2)(tmp0.s2 - tmp0.s1, tmp4.s2 - tmp4.s1);
    float2 out03 = (float2)(tmp0.s1 - tmp0.s3, tmp4.s1 - tmp4.s3);

    // Destination: X = channel (z, z + 1 stored as a pair), Y = tile index, Z = index of the transformed element
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y;

    vstore2(out00, 0, (__global float *)(dst_addr + 0 * dst_stride_z));
    vstore2(out01, 0, (__global float *)(dst_addr + 1 * dst_stride_z));
    vstore2(out02, 0, (__global float *)(dst_addr + 2 * dst_stride_z));
    vstore2(out03, 0, (__global float *)(dst_addr + 3 * dst_stride_z));

#if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
    // Remaining three combined rows for both channels (3x3 only)
    float4 tmp1 = in_row1 + in_row2;
    float4 tmp2 = in_row2 - in_row1;
    float4 tmp3 = in_row1 - in_row3;

    float4 tmp5 = in_row5 + in_row6;
    float4 tmp6 = in_row6 - in_row5;
    float4 tmp7 = in_row5 - in_row7;

    float2 out10 = (float2)(tmp1.s0 - tmp1.s2, tmp5.s0 - tmp5.s2);
    float2 out11 = (float2)(tmp1.s1 + tmp1.s2, tmp5.s1 + tmp5.s2);
    float2 out12 = (float2)(tmp1.s2 - tmp1.s1, tmp5.s2 - tmp5.s1);
    float2 out13 = (float2)(tmp1.s1 - tmp1.s3, tmp5.s1 - tmp5.s3);

    float2 out20 = (float2)(tmp2.s0 - tmp2.s2, tmp6.s0 - tmp6.s2);
    float2 out21 = (float2)(tmp2.s1 + tmp2.s2, tmp6.s1 + tmp6.s2);
    float2 out22 = (float2)(tmp2.s2 - tmp2.s1, tmp6.s2 - tmp6.s1);
    float2 out23 = (float2)(tmp2.s1 - tmp2.s3, tmp6.s1 - tmp6.s3);

    float2 out30 = (float2)(tmp3.s0 - tmp3.s2, tmp7.s0 - tmp7.s2);
    float2 out31 = (float2)(tmp3.s1 + tmp3.s2, tmp7.s1 + tmp7.s2);
    float2 out32 = (float2)(tmp3.s2 - tmp3.s1, tmp7.s2 - tmp7.s1);
    float2 out33 = (float2)(tmp3.s1 - tmp3.s3, tmp7.s1 - tmp7.s3);

    vstore2(out10, 0, (__global float *)(dst_addr + 4 * dst_stride_z));
    vstore2(out11, 0, (__global float *)(dst_addr + 5 * dst_stride_z));
    vstore2(out12, 0, (__global float *)(dst_addr + 6 * dst_stride_z));
    vstore2(out13, 0, (__global float *)(dst_addr + 7 * dst_stride_z));
    vstore2(out20, 0, (__global float *)(dst_addr + 8 * dst_stride_z));
    vstore2(out21, 0, (__global float *)(dst_addr + 9 * dst_stride_z));
    vstore2(out22, 0, (__global float *)(dst_addr + 10 * dst_stride_z));
    vstore2(out23, 0, (__global float *)(dst_addr + 11 * dst_stride_z));
    vstore2(out30, 0, (__global float *)(dst_addr + 12 * dst_stride_z));
    vstore2(out31, 0, (__global float *)(dst_addr + 13 * dst_stride_z));
    vstore2(out32, 0, (__global float *)(dst_addr + 14 * dst_stride_z));
    vstore2(out33, 0, (__global float *)(dst_addr + 15 * dst_stride_z));
#endif // !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL)
}
Giorgio Arena | fe5ef38 | 2018-04-17 10:14:10 +0100 | [diff] [blame] | 1394 | |
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 1395 | /** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1396 | * |
| 1397 | * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5). |
| 1398 | * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0). |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1399 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2 |
| 1400 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2 |
| 1401 | * @note If this kernel is used to perform Winograd input transform 3x1, -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time |
| 1402 | * @note If this kernel is used to perform Winograd input transform 1x3, -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1403 | * |
| 1404 | * @param[in] src_ptr Pointer to the source image. Supported data types: F32 |
| 1405 | * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) |
| 1406 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 1407 | * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes) |
| 1408 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 1409 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image |
| 1410 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 1411 | * @param[in] src_step_z src_stride_z * number of elements along Y processed per workitem(in bytes) |
| 1412 | * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr |
| 1413 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 1414 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 1415 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 1416 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 1417 | * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) |
| 1418 | * @param[in] dst_step_z dst_stride_z * number of elements along Y processed per workitem(in bytes) |
| 1419 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 1420 | */ |
| 1421 | __kernel void winograd_input_transform_4x4_3x3_stepz1_nchw( |
| 1422 | TENSOR3D_DECLARATION(src), |
| 1423 | TENSOR3D_DECLARATION(dst)) |
| 1424 | { |
| 1425 | int x = get_global_id(0); |
| 1426 | int y = get_global_id(1); |
| 1427 | int z = get_global_id(2); |
| 1428 | |
| 1429 | // Compute input address |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1430 | __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * OUTPUT_TILE_W * sizeof(float) + y * OUTPUT_TILE_H * src_stride_y + z * src_stride_z; |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1431 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1432 | src_addr = src_addr - ((int)PAD_LEFT * sizeof(float)) - ((int)PAD_TOP * src_stride_y); |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1433 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1434 | #if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
| 1435 | // Row0 |
| 1436 | float4 d00 = (float4)(*((__global float *)(src_addr + 0 * src_stride_y)), |
| 1437 | *((__global float *)(src_addr + 1 * src_stride_y)), |
| 1438 | *((__global float *)(src_addr + 2 * src_stride_y)), |
| 1439 | *((__global float *)(src_addr + 3 * src_stride_y))); |
| 1440 | float2 d01 = (float2)(*((__global float *)(src_addr + 4 * src_stride_y)), |
| 1441 | *((__global float *)(src_addr + 5 * src_stride_y))); |
| 1442 | #else // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
| 1443 | // Row0 |
| 1444 | float4 d00 = vload4(0, (__global float *)(src_addr + 0 * src_stride_y)); |
| 1445 | float2 d01 = vload2(2, (__global float *)(src_addr + 0 * src_stride_y)); |
| 1446 | #endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
| 1447 | |
| 1448 | float out0 = 0.0f; |
| 1449 | float out1 = 0.0f; |
| 1450 | float out2 = 0.0f; |
| 1451 | float out3 = 0.0f; |
| 1452 | float out4 = 0.0f; |
| 1453 | float out5 = 0.0f; |
| 1454 | |
| 1455 | // Channels [0, 5]: [out00, out01, out02, out03, out04, out05] |
| 1456 | out0 += 16.0f * d00.s0 - 20.0f * d00.s2 + 4.0f * d01.s0; |
| 1457 | out1 += -16.0f * d00.s1 - 16.0f * d00.s2 + 4.0f * d00.s3 + 4.0f * d01.s0; |
| 1458 | out2 += 16.0f * d00.s1 - 16.0f * d00.s2 - 4.0f * d00.s3 + 4.0f * d01.s0; |
| 1459 | out3 += -8.0f * d00.s1 - 4.0f * d00.s2 + 8.0f * d00.s3 + 4.0f * d01.s0; |
| 1460 | out4 += 8.0f * d00.s1 - 4.0f * d00.s2 - 8.0f * d00.s3 + 4.0f * d01.s0; |
| 1461 | out5 += 16.0f * d00.s1 - 20.0f * d00.s3 + 4.0f * d01.s1; |
| 1462 | |
| 1463 | #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1464 | // Row4 |
| 1465 | float4 d40 = vload4(0, (__global float *)(src_addr + 4 * src_stride_y)); |
| 1466 | float2 d41 = vload2(2, (__global float *)(src_addr + 4 * src_stride_y)); |
| 1467 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1468 | // k0, k1, k2, k3, k4, k5 are common terms for row0, row1, row2, row3 and row4 |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1469 | float k0 = d41.s0; |
| 1470 | float k1 = d41.s0; |
| 1471 | float k2 = d41.s0; |
| 1472 | float k3 = d41.s0; |
| 1473 | float k4 = d41.s0; |
| 1474 | float k5 = 0.0f; |
| 1475 | |
| 1476 | k0 += 4.0f * d40.s0 - 5.0f * d40.s2; |
| 1477 | k1 += -4.0f * d40.s1 - 4.0f * d40.s2 + d40.s3; |
| 1478 | k2 += 4.0f * d40.s1 - 4.0f * d40.s2 - d40.s3; |
| 1479 | k3 += -2.0f * d40.s1 + 2.0f * d40.s3 - d40.s2; |
| 1480 | k4 += 2.0f * d40.s1 - 2.0f * d40.s3 - d40.s2; |
| 1481 | k5 += 4.0f * d40.s1 - 5.0f * d40.s3 + d41.s1; |
| 1482 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1483 | out0 += k0; |
| 1484 | out1 += k1; |
| 1485 | out2 += k2; |
| 1486 | out3 += k3; |
| 1487 | out4 += k4; |
| 1488 | out5 += k5; |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1489 | |
| 1490 | // Row2 |
| 1491 | float4 d20 = vload4(0, (__global float *)(src_addr + 2 * src_stride_y)); |
| 1492 | float2 d21 = vload2(2, (__global float *)(src_addr + 2 * src_stride_y)); |
| 1493 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1494 | out0 += -20.0f * d20.s0 + 25.0f * d20.s2 - 5.0f * d21.s0; |
| 1495 | out1 += +20.0f * d20.s1 + 20.0f * d20.s2 - 5.0f * d20.s3 - 5.0f * d21.s0; |
| 1496 | out2 += -20.0f * d20.s1 + 20.0f * d20.s2 + 5.0f * d20.s3 - 5.0f * d21.s0; |
| 1497 | out3 += +10.0f * d20.s1 + 5.0f * d20.s2 - 10.0f * d20.s3 - 5.0f * d21.s0; |
| 1498 | out4 += -10.0f * d20.s1 + 5.0f * d20.s2 + 10.0f * d20.s3 - 5.0f * d21.s0; |
| 1499 | out5 += -20.0f * d20.s1 + 25.0f * d20.s3 - 5.0f * d21.s1; |
| 1500 | #endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
| 1501 | |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1502 | // Compute destination address |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1503 | __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + z * sizeof(float) + (x + y * (int)NUM_TILES_X) * dst_stride_y); |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1504 | |
| 1505 | uint dst_plane_stride = dst_stride_z / sizeof(float); |
| 1506 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1507 | *(dst_addr) = out0; |
| 1508 | dst_addr += dst_plane_stride; |
| 1509 | *(dst_addr) = out1; |
| 1510 | dst_addr += dst_plane_stride; |
| 1511 | *(dst_addr) = out2; |
| 1512 | dst_addr += dst_plane_stride; |
| 1513 | *(dst_addr) = out3; |
| 1514 | dst_addr += dst_plane_stride; |
| 1515 | *(dst_addr) = out4; |
| 1516 | dst_addr += dst_plane_stride; |
| 1517 | *(dst_addr) = out5; |
| 1518 | dst_addr += dst_plane_stride; |
| 1519 | |
| 1520 | #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1521 | float out6 = k0; |
| 1522 | float out7 = k1; |
| 1523 | float out8 = k2; |
| 1524 | float out9 = k3; |
| 1525 | float out10 = k4; |
| 1526 | float out11 = k5; |
| 1527 | float out12 = k0; |
| 1528 | float out13 = k1; |
| 1529 | float out14 = k2; |
| 1530 | float out15 = k3; |
| 1531 | float out16 = k4; |
| 1532 | float out17 = k5; |
| 1533 | float out18 = k0; |
| 1534 | float out19 = k1; |
| 1535 | float out20 = k2; |
| 1536 | float out21 = k3; |
| 1537 | float out22 = k4; |
| 1538 | float out23 = k5; |
| 1539 | float out24 = k0; |
| 1540 | float out25 = k1; |
| 1541 | float out26 = k2; |
| 1542 | float out27 = k3; |
| 1543 | float out28 = k4; |
| 1544 | float out29 = k5; |
| 1545 | |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1546 | // Row1 |
| 1547 | float4 d10 = vload4(0, (__global float *)(src_addr + 1 * src_stride_y)); |
| 1548 | float2 d11 = vload2(2, (__global float *)(src_addr + 1 * src_stride_y)); |
| 1549 | |
| 1550 | // Row3 |
| 1551 | float4 d30 = vload4(0, (__global float *)(src_addr + 3 * src_stride_y)); |
| 1552 | float2 d31 = vload2(2, (__global float *)(src_addr + 3 * src_stride_y)); |
| 1553 | |
| 1554 | // Compute common parts for the channels between [6, 29] |
| 1555 | // Channels [6, 11]: [out10, out11, out12, out13, out14, out15] |
| 1556 | // Channels [12, 17]: [out20, out21, out22, out23, out24, out25] |
| 1557 | float part0 = -16.0f * d20.s0 + 20.0f * d20.s2 - 4.0f * d21.s0; |
| 1558 | float part1 = 16.0f * d10.s0 - 20.0f * d10.s2 + 4.0f * d11.s0 - 4.0f * d30.s0 + 5.0f * d30.s2 - d31.s0; |
| 1559 | float part2 = 16.0f * d20.s2 - 4.0f * d21.s0; |
| 1560 | float part3 = 16.0f * d20.s1 - 4.0f * d20.s3; |
| 1561 | float part4 = 16.0f * d10.s2 - 4.0f * d11.s0 - 4.0f * d30.s2 + d31.s0; |
| 1562 | float part5 = 16.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + d30.s3; |
| 1563 | float part6 = 4.0f * d20.s2 - 4.0f * d21.s0; |
| 1564 | float part7 = 8.0f * d10.s1 - 8.0f * d10.s3 - 2.0f * d30.s1 + 2.0f * d30.s3; |
| 1565 | float part8 = 4.0f * d10.s2 - 4.0f * d11.s0 - d30.s2 + d31.s0; |
| 1566 | float part9 = 8.0f * d20.s1 - 8.0f * d20.s3; |
| 1567 | float part10 = -16.0f * d20.s1 + 20.0f * d20.s3 - 4.0f * d21.s1; |
| 1568 | float part11 = -16.0f * d10.s1 + 20.0f * d10.s3 - 4.0f * d11.s1 + 4.0f * d30.s1 - 5.0f * d30.s3 + d31.s1; |
| 1569 | |
| 1570 | // Channels [18, 23]: [out30, out31, out32, out33, out34, out35] |
| 1571 | // Channels [24, 29]: [out40, out41, out42, out43, out44, out45] |
| 1572 | float part12 = 8.0f * d10.s0 - 10.0f * d10.s2 + 2.0f * d11.s0 - 8.0f * d30.s0 + 10.0f * d30.s2 - 2.0f * d31.s0; |
| 1573 | float part13 = part0 * 0.25f; // -4.0f * d20.s0 + 5.0f * d20.s2 - d21.s0 |
| 1574 | float part14 = part2 * 0.25f; // 4.0f * d20.s2 - d21.s0 |
| 1575 | float part15 = 8.0f * d10.s1 - 2.0f * d10.s3 - 8.0f * d30.s1 + 2.0f * d30.s3; |
| 1576 | float part16 = 8.0f * d10.s2 - 2.0f * d11.s0 - 8.0f * d30.s2 + 2.0f * d31.s0; |
| 1577 | float part17 = part3 * 0.25f; // 4.0f * d20.s1 - d20.s3 |
| 1578 | float part18 = part6 * 0.25f; // d20.s2 - d21.s0 |
| 1579 | float part19 = 4.0f * d10.s1 - 4.0f * d10.s3 - 4.0f * d30.s1 + 4.0f * d30.s3; |
| 1580 | float part20 = 2.0f * d10.s2 - 2.0f * d11.s0 - 2.0f * d30.s2 + 2.0f * d31.s0; |
| 1581 | float part21 = part9 * 0.25f; // 2.0f * (d20.s1 - d20.s3) |
| 1582 | float part22 = part10 * 0.25f; // - 4.0f * d20.s1 + 5.0f * d20.s3 - d21.s1 |
| 1583 | float part23 = part11 * 0.5f + 6.0f * d30.s1 - 7.5f * d30.s3 + 1.5f * d31.s1; // - 8.0f * d10.s1 + 10.0f * d10.s3 - 2.0f * d11.s1 + 8.0f * d30.s1 - 10.0f * d30.s3 + 2.0f * d31.s1; |
| 1584 | |
| 1585 | out6 += part0 - part1; |
| 1586 | out12 += part0 + part1; |
| 1587 | out7 += part2 + part3 + part4 + part5; |
| 1588 | out8 += part2 - part3 + part4 - part5; |
| 1589 | out13 += part2 + part3 - part4 - part5; |
| 1590 | out14 += part2 - part3 - part4 + part5; |
| 1591 | out9 += part6 + part7 + part8 + part9; |
| 1592 | out10 += part6 - part7 + part8 - part9; |
| 1593 | out15 += part6 - part7 - part8 + part9; |
| 1594 | out16 += part6 + part7 - part8 - part9; |
| 1595 | out11 += part10 + part11; |
| 1596 | out17 += part10 - part11; |
| 1597 | |
| 1598 | out18 += part13 - part12; |
| 1599 | out24 += part13 + part12; |
| 1600 | out19 += part14 + part15 + part16 + part17; |
| 1601 | out20 += part14 - part15 + part16 - part17; |
| 1602 | out25 += part14 - part15 - part16 + part17; |
| 1603 | out26 += part14 + part15 - part16 - part17; |
| 1604 | out21 += part18 + part19 + part20 + part21; |
| 1605 | out22 += part18 - part19 + part20 - part21; |
| 1606 | out27 += part18 - part19 - part20 + part21; |
| 1607 | out28 += part18 + part19 - part20 - part21; |
| 1608 | out23 += part22 + part23; |
| 1609 | out29 += part22 - part23; |
| 1610 | |
| 1611 | *(dst_addr) = out6; |
| 1612 | dst_addr += dst_plane_stride; |
| 1613 | *(dst_addr) = out7; |
| 1614 | dst_addr += dst_plane_stride; |
| 1615 | *(dst_addr) = out8; |
| 1616 | dst_addr += dst_plane_stride; |
| 1617 | *(dst_addr) = out9; |
| 1618 | dst_addr += dst_plane_stride; |
| 1619 | *(dst_addr) = out10; |
| 1620 | dst_addr += dst_plane_stride; |
| 1621 | *(dst_addr) = out11; |
| 1622 | dst_addr += dst_plane_stride; |
| 1623 | *(dst_addr) = out12; |
| 1624 | dst_addr += dst_plane_stride; |
| 1625 | *(dst_addr) = out13; |
| 1626 | dst_addr += dst_plane_stride; |
| 1627 | *(dst_addr) = out14; |
| 1628 | dst_addr += dst_plane_stride; |
| 1629 | *(dst_addr) = out15; |
| 1630 | dst_addr += dst_plane_stride; |
| 1631 | *(dst_addr) = out16; |
| 1632 | dst_addr += dst_plane_stride; |
| 1633 | *(dst_addr) = out17; |
| 1634 | dst_addr += dst_plane_stride; |
| 1635 | |
| 1636 | *(dst_addr) = out18; |
| 1637 | dst_addr += dst_plane_stride; |
| 1638 | *(dst_addr) = out19; |
| 1639 | dst_addr += dst_plane_stride; |
| 1640 | *(dst_addr) = out20; |
| 1641 | dst_addr += dst_plane_stride; |
| 1642 | *(dst_addr) = out21; |
| 1643 | dst_addr += dst_plane_stride; |
| 1644 | *(dst_addr) = out22; |
| 1645 | dst_addr += dst_plane_stride; |
| 1646 | *(dst_addr) = out23; |
| 1647 | dst_addr += dst_plane_stride; |
| 1648 | *(dst_addr) = out24; |
| 1649 | dst_addr += dst_plane_stride; |
| 1650 | *(dst_addr) = out25; |
| 1651 | dst_addr += dst_plane_stride; |
| 1652 | *(dst_addr) = out26; |
| 1653 | dst_addr += dst_plane_stride; |
| 1654 | *(dst_addr) = out27; |
| 1655 | dst_addr += dst_plane_stride; |
| 1656 | *(dst_addr) = out28; |
| 1657 | dst_addr += dst_plane_stride; |
| 1658 | *(dst_addr) = out29; |
| 1659 | dst_addr += dst_plane_stride; |
| 1660 | |
| 1661 | // Row5 |
| 1662 | float4 d50 = vload4(0, (__global float *)(src_addr + 5 * src_stride_y)); |
| 1663 | float2 d51 = vload2(2, (__global float *)(src_addr + 5 * src_stride_y)); |
| 1664 | |
| 1665 | // Channels [30, 35] |
| 1666 | out0 = 16.0f * d10.s0 - 20.0f * d10.s2 - 20.0f * d30.s0 + 25.0f * d30.s2 + 4.0f * d50.s0 - 5.0f * d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0; |
| 1667 | out1 = -16.0f * d10.s1 - 16.0f * d10.s2 + 4.0f * d10.s3 + 20.0f * d30.s1 + 20.0f * d30.s2 - 5.0f * d30.s3 - 4.0f * d50.s1 - 4.0f * d50.s2 + d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0; |
| 1668 | out2 = 16.0f * d10.s1 - 16.0f * d10.s2 - 4.0f * d10.s3 - 20.0f * d30.s1 + 20.0f * d30.s2 + 5.0f * d30.s3 + 4.0f * d50.s1 - 4.0f * d50.s2 - d50.s3 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0; |
| 1669 | out3 = -8.0f * d10.s1 - 4.0f * d10.s2 + 8.0f * d10.s3 + 10.0f * d30.s1 - 10.0f * d30.s3 + 5.0f * d30.s2 - 2.0f * d50.s1 + 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0; |
| 1670 | out4 = 8.0f * d10.s1 - 4.0f * d10.s2 - 8.0f * d10.s3 - 10.0f * d30.s1 + 5.0f * d30.s2 + 10.0f * d30.s3 + 2.0f * d50.s1 - 2.0f * d50.s3 - d50.s2 + d51.s0 + 4.0f * d11.s0 - 5.0f * d31.s0; |
| 1671 | out5 = 16.0f * d10.s1 - 20.0f * d10.s3 + 4.0f * d11.s1 - 20.0f * d30.s1 + 25.0f * d30.s3 - 5.0f * d31.s1 + 4.0f * d50.s1 - 5.0f * d50.s3 + d51.s1; |
| 1672 | |
| 1673 | *(dst_addr) = out0; |
| 1674 | dst_addr += dst_plane_stride; |
| 1675 | *(dst_addr) = out1; |
| 1676 | dst_addr += dst_plane_stride; |
| 1677 | *(dst_addr) = out2; |
| 1678 | dst_addr += dst_plane_stride; |
| 1679 | *(dst_addr) = out3; |
| 1680 | dst_addr += dst_plane_stride; |
| 1681 | *(dst_addr) = out4; |
| 1682 | dst_addr += dst_plane_stride; |
| 1683 | *(dst_addr) = out5; |
| 1684 | dst_addr += dst_plane_stride; |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 1685 | #endif // #if !defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 1686 | } |
| 1687 | |
#if defined(SRC_DIM_1) && defined(SRC_DIM_2)
/** This OpenCL kernel computes the input transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC
 *
 * Each work-item processes one channel (get_global_id(0)) of one 6x6 input tile whose top-left corner is at
 * width (get_global_id(1) * 4 - PAD_LEFT) and height (get_global_id(2) * 4 - PAD_TOP). The 36 transformed values
 * are written to 36 consecutive Z planes of the destination, at destination row (get_global_id(1) + get_global_id(2) * NUM_TILES_X).
 *
 * Out-of-range handling: width coordinates are clamped to [-1, SRC_DIM_1]. For an input row whose index falls outside
 * [0, SRC_DIM_2), the whole row is instead read from width -1 (row above the tensor) or width SRC_DIM_1 (row below),
 * with the row index clamped. NOTE(review): this presumably relies on the source tensor having zero-filled padding
 * at those positions - confirm against the host-side configuration.
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note Dimension one of the input tensor (width for NHWC data layout) must be passed at compile time using -DSRC_DIM_1 (e.g. -DSRC_DIM_1=112)
 * @note Dimension two of the input tensor (height for NHWC data layout) must be passed at compile time using -DSRC_DIM_2 (e.g. -DSRC_DIM_2=112)
 *
 * @param[in] src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in] src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in] src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in] src_step_y                        src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[in] src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in] src_step_z                        src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in] dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in] dst_step_x                        dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in] dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in] dst_step_y                        dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in] dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in] dst_step_z                        dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x4_3x3_stepz1_nhwc(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    int x = get_global_id(0); // channel
    int y = get_global_id(1); // tile index along the width
    int z = get_global_id(2); // tile index along the height

    __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * src_stride_x;

    // Width coordinates of the six tile columns, clamped to [-1, SRC_DIM_1]. This clamp is valid for all rows
    int4 y_coord0 = (int4)(y * 4) + (int4)(0, 1, 2, 3) - (int4)PAD_LEFT;
    int2 y_coord1 = (int2)(y * 4) + (int2)(4, 5) - (int2)PAD_LEFT;
    y_coord0      = clamp(y_coord0, -1, SRC_DIM_1);
    y_coord1      = clamp(y_coord1, -1, SRC_DIM_1);

    // Row4
    int z_coord = (z * 4) - PAD_TOP + 4;

    // If z < 0, set y to -1
    int4 valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
    int2 valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
    // If z >= SRC_DIM_2, set y to SRC_DIM_1
    valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
    valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);

    // Clamp z coordinate
    z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1);

    float d40 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d41 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d42 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d43 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d44 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d45 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Row4 contributions, shared by output rows 0-4 (channels [0, 29])
    float k0 = d44;
    float k1 = d44;
    float k2 = d44;
    float k3 = d44;
    float k4 = d44;
    float k5 = 0.0f;

    k0 += 4.0f * d40 - 5.0f * d42;
    k1 += -4.0f * d41 - 4.0f * d42 + d43;
    k2 += 4.0f * d41 - 4.0f * d42 - d43;
    k3 += -2.0f * d41 + 2.0f * d43 - d42;
    k4 += 2.0f * d41 - 2.0f * d43 - d42;
    k5 += 4.0f * d41 - 5.0f * d43 + d45;

    // Row0
    z_coord = (z * 4) - PAD_TOP + 0;

#if PAD_TOP != 0
    valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
    valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, SRC_DIM_2 - 1);
#else  // PAD_TOP != 0
    valid_y0 = y_coord0;
    valid_y1 = y_coord1;
#endif // if PAD_TOP == 0, we cannot read out of bound

    float d00 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d01 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d02 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d03 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d04 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d05 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Row2
    z_coord  = (z * 4) - PAD_TOP + 2;
    valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
    valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, SRC_DIM_2 - 1);

    float d20 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d21 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d22 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d23 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d24 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d25 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Compute destination address: one row per tile, one Z plane per transformed value
    __global float *dst_addr = (__global float *)(dst_ptr + dst_offset_first_element_in_bytes + x * dst_stride_x + (y + z * (int)NUM_TILES_X) * dst_stride_y);

    uint dst_plane_stride = dst_stride_z / sizeof(float);

    float out0  = k0;
    float out1  = k1;
    float out2  = k2;
    float out3  = k3;
    float out4  = k4;
    float out5  = k5;
    float out6  = k0;
    float out7  = k1;
    float out8  = k2;
    float out9  = k3;
    float out10 = k4;
    float out11 = k5;
    float out12 = k0;
    float out13 = k1;
    float out14 = k2;
    float out15 = k3;
    float out16 = k4;
    float out17 = k5;
    float out18 = k0;
    float out19 = k1;
    float out20 = k2;
    float out21 = k3;
    float out22 = k4;
    float out23 = k5;
    float out24 = k0;
    float out25 = k1;
    float out26 = k2;
    float out27 = k3;
    float out28 = k4;
    float out29 = k5;

    // Channels [0, 5]: [out00, out01, out02, out03, out04, out05]
    out0 += 16.0f * d00 - 20.0f * d02 - 20.0f * d20 + 25.0f * d22 + 4.0f * d04 - 5.0f * d24;
    out1 += -16.0f * d01 - 16.0f * d02 + 4.0f * d03 + 20.0f * d21 + 20.0f * d22 - 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
    out2 += 16.0f * d01 - 16.0f * d02 - 4.0f * d03 - 20.0f * d21 + 20.0f * d22 + 5.0f * d23 + 4.0f * d04 - 5.0f * d24;
    out3 += -8.0f * d01 - 4.0f * d02 + 8.0f * d03 + 10.0f * d21 + 5.0f * d22 - 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
    out4 += 8.0f * d01 - 4.0f * d02 - 8.0f * d03 - 10.0f * d21 + 5.0f * d22 + 10.0f * d23 + 4.0f * d04 - 5.0f * d24;
    out5 += 16.0f * d01 - 20.0f * d03 - 20.0f * d21 + 4.0f * d05 + 25.0f * d23 - 5.0f * d25;

    *dst_addr = out0;
    dst_addr += dst_plane_stride;
    *dst_addr = out1;
    dst_addr += dst_plane_stride;
    *dst_addr = out2;
    dst_addr += dst_plane_stride;
    *dst_addr = out3;
    dst_addr += dst_plane_stride;
    *dst_addr = out4;
    dst_addr += dst_plane_stride;
    *dst_addr = out5;
    dst_addr += dst_plane_stride;

    // Row1
    z_coord = (z * 4) - PAD_TOP + 1;
    // Row1 is read without the row-range redirection/clamp used for the other rows.
    // NOTE(review): this assumes PAD_TOP <= 1 and at most one row of bottom padding, so that
    // row index (z * 4 - PAD_TOP + 1) always lies in [0, SRC_DIM_2) - confirm for new configurations.
    valid_y0 = y_coord0;
    valid_y1 = y_coord1;

    float d10 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d11 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d12 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d13 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d14 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d15 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Row3
    z_coord  = (z * 4) - PAD_TOP + 3;
    valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
    valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, SRC_DIM_2 - 1);

    float d30 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d31 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d32 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d33 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d34 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d35 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Compute common parts for the channels between [6, 29]
    // Channels [6, 11]:  [out10, out11, out12, out13, out14, out15]
    // Channels [12, 17]: [out20, out21, out22, out23, out24, out25]
    float part0  = -16.0f * d20 + 20.0f * d22 - 4.0f * d24;
    float part1  = 16.0f * d10 - 20.0f * d12 + 4.0f * d14 - 4.0f * d30 + 5.0f * d32 - d34;
    float part2  = 16.0f * d22 - 4.0f * d24;
    float part3  = 16.0f * d21 - 4.0f * d23;
    float part4  = 16.0f * d12 - 4.0f * d14 - 4.0f * d32 + d34;
    float part5  = 16.0f * d11 - 4.0f * d13 - 4.0f * d31 + d33;
    float part6  = 4.0f * d22 - 4.0f * d24;
    float part7  = 8.0f * d11 - 8.0f * d13 - 2.0f * d31 + 2.0f * d33;
    float part8  = 4.0f * d12 - 4.0f * d14 - d32 + d34;
    float part9  = 8.0f * d21 - 8.0f * d23;
    float part10 = -16.0f * d21 + 20.0f * d23 - 4.0f * d25;
    float part11 = -16.0f * d11 + 20.0f * d13 - 4.0f * d15 + 4.0f * d31 - 5.0f * d33 + d35;

    // Channels [18, 23]: [out30, out31, out32, out33, out34, out35]
    // Channels [24, 29]: [out40, out41, out42, out43, out44, out45]
    float part12 = 8.0f * d10 - 10.0f * d12 + 2.0f * d14 - 8.0f * d30 + 10.0f * d32 - 2.0f * d34;
    float part13 = part0 * 0.25f; // -4.0f * d20 + 5.0f * d22 - d24
    float part14 = part2 * 0.25f; // 4.0f * d22 - d24
    float part15 = 8.0f * d11 - 2.0f * d13 - 8.0f * d31 + 2.0f * d33;
    float part16 = 8.0f * d12 - 2.0f * d14 - 8.0f * d32 + 2.0f * d34;
    float part17 = part3 * 0.25f; // 4.0f * d21 - d23
    float part18 = part6 * 0.25f; // d22 - d24
    float part19 = 4.0f * d11 - 4.0f * d13 - 4.0f * d31 + 4.0f * d33;
    float part20 = 2.0f * d12 - 2.0f * d14 - 2.0f * d32 + 2.0f * d34;
    float part21 = part9 * 0.25f;  // 2.0f * (d21 - d23)
    float part22 = part10 * 0.25f; // - 4.0f * d21 + 5.0f * d23 - d25
    float part23 = part11 * 0.5f + 6.0f * d31 - 7.5f * d33 + 1.5f * d35; // - 8.0f * d11 + 10.0f * d13 - 2.0f * d15 + 8.0f * d31 - 10.0f * d33 + 2.0f * d35;

    out6 += part0 - part1;
    out12 += part0 + part1;
    out7 += part2 + part3 + part4 + part5;
    out8 += part2 - part3 + part4 - part5;
    out13 += part2 + part3 - part4 - part5;
    out14 += part2 - part3 - part4 + part5;
    out9 += part6 + part7 + part8 + part9;
    out10 += part6 - part7 + part8 - part9;
    out15 += part6 - part7 - part8 + part9;
    out16 += part6 + part7 - part8 - part9;
    out11 += part10 + part11;
    out17 += part10 - part11;

    out18 += part13 - part12;
    out24 += part13 + part12;
    out19 += part14 + part15 + part16 + part17;
    out20 += part14 - part15 + part16 - part17;
    out25 += part14 - part15 - part16 + part17;
    out26 += part14 + part15 - part16 - part17;
    out21 += part18 + part19 + part20 + part21;
    out22 += part18 - part19 + part20 - part21;
    out27 += part18 - part19 - part20 + part21;
    out28 += part18 + part19 - part20 - part21;
    out23 += part22 + part23;
    out29 += part22 - part23;

    *dst_addr = out6;
    dst_addr += dst_plane_stride;
    *dst_addr = out7;
    dst_addr += dst_plane_stride;
    *dst_addr = out8;
    dst_addr += dst_plane_stride;
    *dst_addr = out9;
    dst_addr += dst_plane_stride;
    *dst_addr = out10;
    dst_addr += dst_plane_stride;
    *dst_addr = out11;
    dst_addr += dst_plane_stride;
    *dst_addr = out12;
    dst_addr += dst_plane_stride;
    *dst_addr = out13;
    dst_addr += dst_plane_stride;
    *dst_addr = out14;
    dst_addr += dst_plane_stride;
    *dst_addr = out15;
    dst_addr += dst_plane_stride;
    *dst_addr = out16;
    dst_addr += dst_plane_stride;
    *dst_addr = out17;
    dst_addr += dst_plane_stride;

    *dst_addr = out18;
    dst_addr += dst_plane_stride;
    *dst_addr = out19;
    dst_addr += dst_plane_stride;
    *dst_addr = out20;
    dst_addr += dst_plane_stride;
    *dst_addr = out21;
    dst_addr += dst_plane_stride;
    *dst_addr = out22;
    dst_addr += dst_plane_stride;
    *dst_addr = out23;
    dst_addr += dst_plane_stride;
    *dst_addr = out24;
    dst_addr += dst_plane_stride;
    *dst_addr = out25;
    dst_addr += dst_plane_stride;
    *dst_addr = out26;
    dst_addr += dst_plane_stride;
    *dst_addr = out27;
    dst_addr += dst_plane_stride;
    *dst_addr = out28;
    dst_addr += dst_plane_stride;
    *dst_addr = out29;
    dst_addr += dst_plane_stride;

    // Row5
    z_coord  = (z * 4) - PAD_TOP + 5;
    valid_y0 = select(y_coord0, -1, (int4)z_coord < 0);
    valid_y1 = select(y_coord1, -1, (int2)z_coord < 0);
    valid_y0 = select(valid_y0, SRC_DIM_1, (int4)z_coord >= SRC_DIM_2);
    valid_y1 = select(valid_y1, SRC_DIM_1, (int2)z_coord >= SRC_DIM_2);
    z_coord  = clamp(z_coord, 0, SRC_DIM_2 - 1);

    float d50 = *(__global float *)(src_addr + valid_y0.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d51 = *(__global float *)(src_addr + valid_y0.s1 * (int)src_stride_y + z_coord * src_stride_z);
    float d52 = *(__global float *)(src_addr + valid_y0.s2 * (int)src_stride_y + z_coord * src_stride_z);
    float d53 = *(__global float *)(src_addr + valid_y0.s3 * (int)src_stride_y + z_coord * src_stride_z);
    float d54 = *(__global float *)(src_addr + valid_y1.s0 * (int)src_stride_y + z_coord * src_stride_z);
    float d55 = *(__global float *)(src_addr + valid_y1.s1 * (int)src_stride_y + z_coord * src_stride_z);

    // Channels [30, 35]
    out0 = 16.0f * d10 - 20.0f * d12 - 20.0f * d30 + 25.0f * d32 + 4.0f * d50 - 5.0f * d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out1 = -16.0f * d11 - 16.0f * d12 + 4.0f * d13 + 20.0f * d31 + 20.0f * d32 - 5.0f * d33 - 4.0f * d51 - 4.0f * d52 + d53 + d54 + 4.0f * d14 - 5.0f * d34;
    out2 = 16.0f * d11 - 16.0f * d12 - 4.0f * d13 - 20.0f * d31 + 20.0f * d32 + 5.0f * d33 + 4.0f * d51 - 4.0f * d52 - d53 + d54 + 4.0f * d14 - 5.0f * d34;
    out3 = -8.0f * d11 - 4.0f * d12 + 8.0f * d13 + 10.0f * d31 - 10.0f * d33 + 5.0f * d32 - 2.0f * d51 + 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out4 = 8.0f * d11 - 4.0f * d12 - 8.0f * d13 - 10.0f * d31 + 5.0f * d32 + 10.0f * d33 + 2.0f * d51 - 2.0f * d53 - d52 + d54 + 4.0f * d14 - 5.0f * d34;
    out5 = 16.0f * d11 - 20.0f * d13 + 4.0f * d15 - 20.0f * d31 + 25.0f * d33 - 5.0f * d35 + 4.0f * d51 - 5.0f * d53 + d55;

    *dst_addr = out0;
    dst_addr += dst_plane_stride;
    *dst_addr = out1;
    dst_addr += dst_plane_stride;
    *dst_addr = out2;
    dst_addr += dst_plane_stride;
    *dst_addr = out3;
    dst_addr += dst_plane_stride;
    *dst_addr = out4;
    dst_addr += dst_plane_stride;
    *dst_addr = out5;
    dst_addr += dst_plane_stride;
}

#endif // defined(SRC_DIM_1) && defined(SRC_DIM_2)
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 2031 | |
/** Applies the per-row step of the 4x4/5x5 Winograd input transform to the eight lanes of @p tmp.
 *
 * Used, for example, by winograd_input_transform_4x4_5x5_stepz1_nchw once per intermediate row (tmp0..tmp7).
 *
 * @param[out]    out       Eight-lane vector (e.g. float8) whose lanes s0..s7 receive the transformed values.
 * @param[in]     tmp       Eight-lane vector holding the input lanes s0..s7 (not modified).
 * @param[in,out] comm_fact Eight-lane scratch vector; lanes s0..s6 are overwritten with shared sub-expressions.
 *
 * @note @p out, @p tmp and @p comm_fact must be distinct objects: lanes of @p out and @p comm_fact are written
 *       while lanes of @p tmp are still being read.
 * @note Expands to a single statement (do/while(0)), so it must be followed by a semicolon. Arguments are
 *       parenthesized in the expansion, so lvalue expressions such as *ptr are accepted.
 */
#define OUTPUT_ROW_4x4_5x5(out, tmp, comm_fact)                                               \
    do                                                                                        \
    {                                                                                         \
        (comm_fact).s0 = (tmp).s2 - 4.25f * (tmp).s4 + (tmp).s6;                              \
        (comm_fact).s1 = (tmp).s1 - 4.25f * (tmp).s3 + (tmp).s5;                              \
        (comm_fact).s2 = 2.5f * (tmp).s3;                                                     \
        (comm_fact).s3 = 0.5f * (tmp).s1 + 2.f * (tmp).s5 - (comm_fact).s2;                   \
        (comm_fact).s4 = 0.25f * (tmp).s2 - 1.25f * (tmp).s4 + (tmp).s6;                      \
        (comm_fact).s5 = 4.f * (tmp).s2 + (tmp).s6 - 5.f * (tmp).s4;                          \
        (comm_fact).s6 = 2.f * (tmp).s1 + 0.5f * (tmp).s5 - (comm_fact).s2;                   \
                                                                                              \
        (out).s0 = (tmp).s0 - (tmp).s6 + 5.25f * (tmp).s4 - 5.25f * (tmp).s2;                 \
        (out).s1 = (comm_fact).s0 + (comm_fact).s1;                                           \
        (out).s2 = (comm_fact).s0 - (comm_fact).s1;                                           \
        (out).s3 = (comm_fact).s3 + (comm_fact).s4;                                           \
        (out).s4 = (comm_fact).s4 - (comm_fact).s3;                                           \
        (out).s5 = (comm_fact).s5 + (comm_fact).s6;                                           \
        (out).s6 = (comm_fact).s5 - (comm_fact).s6;                                           \
        (out).s7 = (tmp).s7 - (tmp).s1 + 5.25f * (tmp).s3 - 5.25f * (tmp).s5;                 \
    } while(0)
| 2051 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2052 | /** This OpenCL kernel computes the input transform when the kernel size is 5x5 and the output tile is 4x4 when the data layout is NCHW |
Giorgio Arena | fe5ef38 | 2018-04-17 10:14:10 +0100 | [diff] [blame] | 2053 | * |
| 2054 | * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (i.e.-DNUM_TILES_X=5). |
| 2055 | * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (i.e.-DPAD_LEFT=1 and -DPAD_TOP=0). |
| 2056 | * |
| 2057 | * @param[in] src_ptr Pointer to the source image. Supported data types: F32 |
| 2058 | * @param[in] src_stride_x Stride of the source image in X dimension (in bytes) |
| 2059 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 2060 | * @param[in] src_stride_y Stride of the source image in Y dimension (in bytes) |
| 2061 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 2062 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source image |
| 2063 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 2064 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 2065 | * @param[in] dst_ptr Pointer to the destination tensor. Supported data types: as @p src_ptr |
| 2066 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 2067 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 2068 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 2069 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 2070 | * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) |
| 2071 | * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes) |
| 2072 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 2073 | */ |
| 2074 | __kernel void winograd_input_transform_4x4_5x5_stepz1_nchw( |
| 2075 | TENSOR3D_DECLARATION(src), |
| 2076 | TENSOR3D_DECLARATION(dst)) |
| 2077 | { |
| 2078 | int x = get_global_id(0); |
| 2079 | int y = get_global_id(1); |
| 2080 | int z = get_global_id(2); |
| 2081 | |
| 2082 | // Compute input address |
| 2083 | __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * 4 * src_stride_x + y * 4 * src_stride_y + z * src_stride_z; |
| 2084 | |
| 2085 | src_addr = src_addr - ((int)PAD_LEFT * src_stride_x) - ((int)PAD_TOP * src_stride_y); |
| 2086 | |
| 2087 | // Load 8x8 input tile |
| 2088 | const float8 in_row0 = vload8(0, (__global float *)(src_addr + 0 * src_stride_y)); |
| 2089 | const float8 in_row1 = vload8(0, (__global float *)(src_addr + 1 * src_stride_y)); |
| 2090 | const float8 in_row2 = vload8(0, (__global float *)(src_addr + 2 * src_stride_y)); |
| 2091 | const float8 in_row3 = vload8(0, (__global float *)(src_addr + 3 * src_stride_y)); |
| 2092 | const float8 in_row4 = vload8(0, (__global float *)(src_addr + 4 * src_stride_y)); |
| 2093 | const float8 in_row5 = vload8(0, (__global float *)(src_addr + 5 * src_stride_y)); |
| 2094 | const float8 in_row6 = vload8(0, (__global float *)(src_addr + 6 * src_stride_y)); |
| 2095 | const float8 in_row7 = vload8(0, (__global float *)(src_addr + 7 * src_stride_y)); |
| 2096 | |
| 2097 | // Calculate common factors for intermediate tensor |
| 2098 | float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4; |
| 2099 | float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3; |
| 2100 | float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6; |
| 2101 | |
| 2102 | // Calculate intermediate tensor and reuse common factor vectors |
| 2103 | const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2; |
| 2104 | const float8 tmp1 = comm_fact0 + comm_fact1; |
| 2105 | const float8 tmp2 = comm_fact0 - comm_fact1; |
| 2106 | |
| 2107 | comm_fact0 = 2.5f * in_row3; |
| 2108 | comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5; |
| 2109 | |
| 2110 | const float8 tmp3 = comm_fact1 + comm_fact2; |
| 2111 | const float8 tmp4 = comm_fact2 - comm_fact1; |
| 2112 | |
| 2113 | comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5; |
| 2114 | comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6; |
| 2115 | |
| 2116 | const float8 tmp5 = comm_fact1 + comm_fact2; |
| 2117 | const float8 tmp6 = comm_fact2 - comm_fact1; |
| 2118 | const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5; |
| 2119 | |
| 2120 | // Calculate output rows (reuse comm_fact0 vector) |
| 2121 | float8 out0, out1, out2, out3, out4, out5, out6, out7; |
| 2122 | |
| 2123 | OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); |
| 2124 | OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0); |
| 2125 | OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0); |
| 2126 | OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0); |
| 2127 | OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0); |
| 2128 | OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0); |
| 2129 | OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0); |
| 2130 | OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0); |
| 2131 | |
| 2132 | // Store values across the 64 channels |
| 2133 | __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + z * dst_stride_x + (x + y * (int)NUM_TILES_X) * dst_stride_y; |
| 2134 | |
| 2135 | *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0; |
| 2136 | *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1; |
| 2137 | *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2; |
| 2138 | *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3; |
| 2139 | *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4; |
| 2140 | *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5; |
| 2141 | *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6; |
| 2142 | *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7; |
| 2143 | *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0; |
| 2144 | *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1; |
| 2145 | *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2; |
| 2146 | *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3; |
| 2147 | *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4; |
| 2148 | *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5; |
| 2149 | *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6; |
| 2150 | *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7; |
| 2151 | *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0; |
| 2152 | *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1; |
| 2153 | *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2; |
| 2154 | *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3; |
| 2155 | *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4; |
| 2156 | *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5; |
| 2157 | *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6; |
| 2158 | *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7; |
| 2159 | *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0; |
| 2160 | *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1; |
| 2161 | *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2; |
| 2162 | *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3; |
| 2163 | *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4; |
| 2164 | *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5; |
| 2165 | *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6; |
| 2166 | *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7; |
| 2167 | *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0; |
| 2168 | *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1; |
| 2169 | *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2; |
| 2170 | *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3; |
| 2171 | *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4; |
| 2172 | *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5; |
| 2173 | *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6; |
| 2174 | *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7; |
| 2175 | *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0; |
| 2176 | *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1; |
| 2177 | *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2; |
| 2178 | *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3; |
| 2179 | *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4; |
| 2180 | *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5; |
| 2181 | *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6; |
| 2182 | *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7; |
| 2183 | *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0; |
| 2184 | *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1; |
| 2185 | *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2; |
| 2186 | *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3; |
| 2187 | *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4; |
| 2188 | *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5; |
| 2189 | *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6; |
| 2190 | *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7; |
| 2191 | *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0; |
| 2192 | *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1; |
| 2193 | *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2; |
| 2194 | *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3; |
| 2195 | *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4; |
| 2196 | *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5; |
| 2197 | *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6; |
| 2198 | *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7; |
| 2199 | } |
Giorgio Arena | be39f12 | 2018-06-08 17:50:38 +0100 | [diff] [blame] | 2200 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2201 | #if defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) |
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 2x1 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_2x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 2x2/3x3 NCHW implementation.
    // NOTE(review): the 3x1 specialisation is presumably selected inside the delegate via
    // -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL / -DOUTPUT_TILE_H=1 (not visible here) - confirm there.
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2248 | |
/** This OpenCL kernel computes the input transform when the kernel size is 3x1, the output tile is 2x1 and the number of channels is a multiple of 2 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_2x1_3x1_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 2x2/3x3 "stepz2" NCHW implementation.
    // NOTE(review): "stepz2" presumably means two channels per work-item, hence the
    // multiple-of-2 channel requirement - confirm in the delegate.
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2295 | |
/** This OpenCL kernel computes the input transform when the kernel size is 3x1 and the output tile is 4x1 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1
 * @note -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_4x1_3x1_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 4x4/3x3 NCHW implementation.
    // NOTE(review): the 3x1 specialisation is presumably selected inside the delegate via
    // -DWINOGRAD_INPUT_TRANSFORM_HORIZONTAL / -DOUTPUT_TILE_H=1 (not visible here) - confirm there.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2342 | #endif // defined(WINOGRAD_INPUT_TRANSFORM_HORIZONTAL) |
| 2343 | |
| 2344 | #if defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x2 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 2x2/3x3 NCHW implementation.
    // NOTE(review): the 1x3 specialisation is presumably selected inside the delegate via
    // -DWINOGRAD_INPUT_TRANSFORM_VERTICAL / -DOUTPUT_TILE_W=1 (not visible here) - confirm there.
    winograd_input_transform_2x2_3x3_stepz1_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2391 | |
/** This OpenCL kernel computes the input transform when the kernel size is 1x3, the output tile is 1x2 and the number of channels is a multiple of 2 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x2_1x3_stepz2_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 2x2/3x3 "stepz2" NCHW implementation.
    // NOTE(review): "stepz2" presumably means two channels per work-item, hence the
    // multiple-of-2 channel requirement - confirm in the delegate.
    winograd_input_transform_2x2_3x3_stepz2_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2438 | |
/** This OpenCL kernel computes the input transform when the kernel size is 1x3 and the output tile is 1x4 (NCHW data layout)
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note -DWINOGRAD_INPUT_TRANSFORM_VERTICAL has to be passed at compile time
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
__kernel void winograd_input_transform_1x4_1x3_stepz1_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst))
{
    // Forward every argument unchanged to the 4x4/3x3 NCHW implementation.
    // NOTE(review): the 1x3 specialisation is presumably selected inside the delegate via
    // -DWINOGRAD_INPUT_TRANSFORM_VERTICAL / -DOUTPUT_TILE_W=1 (not visible here) - confirm there.
    winograd_input_transform_4x4_3x3_stepz1_nchw(src_ptr, src_stride_x, src_step_x, src_stride_y, src_step_y,
                                                 src_stride_z, src_step_z, src_offset_first_element_in_bytes,
                                                 dst_ptr, dst_stride_x, dst_step_x, dst_stride_y, dst_step_y,
                                                 dst_stride_z, dst_step_z, dst_offset_first_element_in_bytes);
}
| 2485 | #endif // defined(WINOGRAD_INPUT_TRANSFORM_VERTICAL) |
| 2486 | |
| 2487 | #if defined(SRC_DIM_1) && defined(SRC_DIM_2) |
/** This OpenCL kernel computes the input transform when the kernel size is 5x5, the output tile is 4x4 and the data layout is NHWC
 *
 * @note The number of tiles in the x axis must be passed at compile time using -DNUM_TILES_X (e.g. -DNUM_TILES_X=5).
 * @note The pad left and pad top must be passed at compile time using -DPAD_LEFT and -DPAD_TOP (e.g. -DPAD_LEFT=1 and -DPAD_TOP=0).
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4
 * @note The sizes of the source tensor along dimensions 1 and 2 (width and height for NHWC) must be passed at compile time
 *       using -DSRC_DIM_1 and -DSRC_DIM_2; they are used to clamp out-of-bounds reads.
 *
 * @param[in]  src_ptr                           Pointer to the source image. Supported data types: F32
 * @param[in]  src_stride_x                      Stride of the source image in X dimension (in bytes)
 * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  src_stride_y                      Stride of the source image in Y dimension (in bytes)
 * @param[in]  src_step_y                        src_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  src_stride_z                      Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                        src_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  src_offset_first_element_in_bytes The offset of the first element in the source image
 * @param[out] dst_ptr                           Pointer to the destination tensor. Supported data types: as @p src_ptr
 * @param[in]  dst_stride_x                      Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                        dst_stride_x * number of elements along X processed per workitem (in bytes)
 * @param[in]  dst_stride_y                      Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                        dst_stride_y * number of elements along Y processed per workitem (in bytes)
 * @param[in]  dst_stride_z                      Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                        dst_stride_z * number of elements along Z processed per workitem (in bytes)
 * @param[in]  dst_offset_first_element_in_bytes The offset of the first element in the destination tensor
 */
| 2512 | __kernel void winograd_input_transform_4x4_5x5_stepz1_nhwc( |
| 2513 | TENSOR3D_DECLARATION(src), |
| 2514 | TENSOR3D_DECLARATION(dst)) |
| 2515 | { |
| 2516 | int x = get_global_id(0); |
| 2517 | int y = get_global_id(1); |
| 2518 | int z = get_global_id(2); |
| 2519 | |
| 2520 | // Compute input address |
| 2521 | __global uchar *src_addr = src_ptr + src_offset_first_element_in_bytes + x * sizeof(float); |
| 2522 | |
| 2523 | // Clamp coordinates. This clamp is valid for all rows |
| 2524 | int8 y_coord = (int8)(y * 4) + (int8)(0, 1, 2, 3, 4, 5, 6, 7) - (int8)PAD_LEFT; |
| 2525 | y_coord = clamp(y_coord, -1, SRC_DIM_1); |
| 2526 | |
| 2527 | // Load 8x8 input tile |
| 2528 | float8 in_row0, in_row1, in_row2, in_row3, in_row4, in_row5, in_row6, in_row7; |
| 2529 | |
| 2530 | // Row0 |
| 2531 | int z_coord = (z * 4) - PAD_TOP + 0; |
| 2532 | int8 valid_y = select(y_coord, -1, (int8)z_coord < 0); // If z < 0, set y to -1 |
| 2533 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); // If z >= SRC_DIM_2, set y to SRC_DIM_2 |
| 2534 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); // Clamp z coordinate |
| 2535 | |
| 2536 | in_row0.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2537 | in_row0.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2538 | in_row0.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2539 | in_row0.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2540 | in_row0.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2541 | in_row0.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2542 | in_row0.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2543 | in_row0.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2544 | |
| 2545 | // Row1 |
| 2546 | z_coord = (z * 4) - PAD_TOP + 1; |
| 2547 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2548 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2549 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2550 | |
| 2551 | in_row1.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2552 | in_row1.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2553 | in_row1.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2554 | in_row1.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2555 | in_row1.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2556 | in_row1.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2557 | in_row1.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2558 | in_row1.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2559 | |
| 2560 | // Row2 |
| 2561 | z_coord = (z * 4) - PAD_TOP + 2; |
| 2562 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2563 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2564 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2565 | |
| 2566 | in_row2.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2567 | in_row2.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2568 | in_row2.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2569 | in_row2.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2570 | in_row2.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2571 | in_row2.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2572 | in_row2.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2573 | in_row2.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2574 | |
| 2575 | // Row3 |
| 2576 | z_coord = (z * 4) - PAD_TOP + 3; |
| 2577 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2578 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2579 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2580 | |
| 2581 | in_row3.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2582 | in_row3.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2583 | in_row3.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2584 | in_row3.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2585 | in_row3.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2586 | in_row3.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2587 | in_row3.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2588 | in_row3.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2589 | |
| 2590 | // Row4 |
| 2591 | z_coord = (z * 4) - PAD_TOP + 4; |
| 2592 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2593 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2594 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2595 | |
| 2596 | in_row4.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2597 | in_row4.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2598 | in_row4.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2599 | in_row4.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2600 | in_row4.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2601 | in_row4.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2602 | in_row4.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2603 | in_row4.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2604 | |
| 2605 | // Row5 |
| 2606 | z_coord = (z * 4) - PAD_TOP + 5; |
| 2607 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2608 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2609 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2610 | |
| 2611 | in_row5.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2612 | in_row5.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2613 | in_row5.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2614 | in_row5.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2615 | in_row5.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2616 | in_row5.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2617 | in_row5.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2618 | in_row5.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2619 | |
| 2620 | // Row6 |
| 2621 | z_coord = (z * 4) - PAD_TOP + 6; |
| 2622 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2623 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2624 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2625 | |
| 2626 | in_row6.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2627 | in_row6.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2628 | in_row6.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2629 | in_row6.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2630 | in_row6.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2631 | in_row6.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2632 | in_row6.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2633 | in_row6.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2634 | |
| 2635 | // Row7 |
| 2636 | z_coord = (z * 4) - PAD_TOP + 7; |
| 2637 | valid_y = select(y_coord, -1, (int8)z_coord < 0); |
| 2638 | valid_y = select(valid_y, SRC_DIM_1, (int8)z_coord >= SRC_DIM_2); |
| 2639 | z_coord = clamp(z_coord, 0, SRC_DIM_2 - 1); |
| 2640 | |
| 2641 | in_row7.s0 = *(__global float *)(src_addr + valid_y.s0 * (int)src_stride_y + z_coord * src_stride_z); |
| 2642 | in_row7.s1 = *(__global float *)(src_addr + valid_y.s1 * (int)src_stride_y + z_coord * src_stride_z); |
| 2643 | in_row7.s2 = *(__global float *)(src_addr + valid_y.s2 * (int)src_stride_y + z_coord * src_stride_z); |
| 2644 | in_row7.s3 = *(__global float *)(src_addr + valid_y.s3 * (int)src_stride_y + z_coord * src_stride_z); |
| 2645 | in_row7.s4 = *(__global float *)(src_addr + valid_y.s4 * (int)src_stride_y + z_coord * src_stride_z); |
| 2646 | in_row7.s5 = *(__global float *)(src_addr + valid_y.s5 * (int)src_stride_y + z_coord * src_stride_z); |
| 2647 | in_row7.s6 = *(__global float *)(src_addr + valid_y.s6 * (int)src_stride_y + z_coord * src_stride_z); |
| 2648 | in_row7.s7 = *(__global float *)(src_addr + valid_y.s7 * (int)src_stride_y + z_coord * src_stride_z); |
| 2649 | |
| 2650 | // Calculate common factors for intermediate tensor |
| 2651 | float8 comm_fact0 = in_row2 + in_row6 - 4.25f * in_row4; |
| 2652 | float8 comm_fact1 = in_row1 + in_row5 - 4.25f * in_row3; |
| 2653 | float8 comm_fact2 = 0.25f * in_row2 - 1.25f * in_row4 + in_row6; |
| 2654 | |
| 2655 | // Calculate intermediate tensor and reuse common factor vectors |
| 2656 | const float8 tmp0 = in_row0 - in_row6 + 5.25f * in_row4 - 5.25f * in_row2; |
| 2657 | const float8 tmp1 = comm_fact0 + comm_fact1; |
| 2658 | const float8 tmp2 = comm_fact0 - comm_fact1; |
| 2659 | |
| 2660 | comm_fact0 = 2.5f * in_row3; |
| 2661 | comm_fact1 = 0.5f * in_row1 - comm_fact0 + 2.f * in_row5; |
| 2662 | |
| 2663 | const float8 tmp3 = comm_fact1 + comm_fact2; |
| 2664 | const float8 tmp4 = comm_fact2 - comm_fact1; |
| 2665 | |
| 2666 | comm_fact1 = 2.f * in_row1 - comm_fact0 + 0.5f * in_row5; |
| 2667 | comm_fact2 = 4.f * in_row2 - 5.f * in_row4 + in_row6; |
| 2668 | |
| 2669 | const float8 tmp5 = comm_fact1 + comm_fact2; |
| 2670 | const float8 tmp6 = comm_fact2 - comm_fact1; |
| 2671 | const float8 tmp7 = in_row7 - in_row1 + 5.25f * in_row3 - 5.25f * in_row5; |
| 2672 | |
| 2673 | // Calculate output rows (reuse comm_fact0 vector) |
| 2674 | float8 out0, out1, out2, out3, out4, out5, out6, out7; |
| 2675 | |
| 2676 | OUTPUT_ROW_4x4_5x5(out0, tmp0, comm_fact0); |
| 2677 | OUTPUT_ROW_4x4_5x5(out1, tmp1, comm_fact0); |
| 2678 | OUTPUT_ROW_4x4_5x5(out2, tmp2, comm_fact0); |
| 2679 | OUTPUT_ROW_4x4_5x5(out3, tmp3, comm_fact0); |
| 2680 | OUTPUT_ROW_4x4_5x5(out4, tmp4, comm_fact0); |
| 2681 | OUTPUT_ROW_4x4_5x5(out5, tmp5, comm_fact0); |
| 2682 | OUTPUT_ROW_4x4_5x5(out6, tmp6, comm_fact0); |
| 2683 | OUTPUT_ROW_4x4_5x5(out7, tmp7, comm_fact0); |
| 2684 | |
| 2685 | // Store values across the 64 channels |
| 2686 | __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x * sizeof(float) + (y + z * (int)NUM_TILES_X) * dst_stride_y; |
| 2687 | |
| 2688 | *((__global float *)(dst_addr + 0 * dst_stride_z)) = out0.s0; |
| 2689 | *((__global float *)(dst_addr + 1 * dst_stride_z)) = out0.s1; |
| 2690 | *((__global float *)(dst_addr + 2 * dst_stride_z)) = out0.s2; |
| 2691 | *((__global float *)(dst_addr + 3 * dst_stride_z)) = out0.s3; |
| 2692 | *((__global float *)(dst_addr + 4 * dst_stride_z)) = out0.s4; |
| 2693 | *((__global float *)(dst_addr + 5 * dst_stride_z)) = out0.s5; |
| 2694 | *((__global float *)(dst_addr + 6 * dst_stride_z)) = out0.s6; |
| 2695 | *((__global float *)(dst_addr + 7 * dst_stride_z)) = out0.s7; |
| 2696 | *((__global float *)(dst_addr + 8 * dst_stride_z)) = out1.s0; |
| 2697 | *((__global float *)(dst_addr + 9 * dst_stride_z)) = out1.s1; |
| 2698 | *((__global float *)(dst_addr + 10 * dst_stride_z)) = out1.s2; |
| 2699 | *((__global float *)(dst_addr + 11 * dst_stride_z)) = out1.s3; |
| 2700 | *((__global float *)(dst_addr + 12 * dst_stride_z)) = out1.s4; |
| 2701 | *((__global float *)(dst_addr + 13 * dst_stride_z)) = out1.s5; |
| 2702 | *((__global float *)(dst_addr + 14 * dst_stride_z)) = out1.s6; |
| 2703 | *((__global float *)(dst_addr + 15 * dst_stride_z)) = out1.s7; |
| 2704 | *((__global float *)(dst_addr + 16 * dst_stride_z)) = out2.s0; |
| 2705 | *((__global float *)(dst_addr + 17 * dst_stride_z)) = out2.s1; |
| 2706 | *((__global float *)(dst_addr + 18 * dst_stride_z)) = out2.s2; |
| 2707 | *((__global float *)(dst_addr + 19 * dst_stride_z)) = out2.s3; |
| 2708 | *((__global float *)(dst_addr + 20 * dst_stride_z)) = out2.s4; |
| 2709 | *((__global float *)(dst_addr + 21 * dst_stride_z)) = out2.s5; |
| 2710 | *((__global float *)(dst_addr + 22 * dst_stride_z)) = out2.s6; |
| 2711 | *((__global float *)(dst_addr + 23 * dst_stride_z)) = out2.s7; |
| 2712 | *((__global float *)(dst_addr + 24 * dst_stride_z)) = out3.s0; |
| 2713 | *((__global float *)(dst_addr + 25 * dst_stride_z)) = out3.s1; |
| 2714 | *((__global float *)(dst_addr + 26 * dst_stride_z)) = out3.s2; |
| 2715 | *((__global float *)(dst_addr + 27 * dst_stride_z)) = out3.s3; |
| 2716 | *((__global float *)(dst_addr + 28 * dst_stride_z)) = out3.s4; |
| 2717 | *((__global float *)(dst_addr + 29 * dst_stride_z)) = out3.s5; |
| 2718 | *((__global float *)(dst_addr + 30 * dst_stride_z)) = out3.s6; |
| 2719 | *((__global float *)(dst_addr + 31 * dst_stride_z)) = out3.s7; |
| 2720 | *((__global float *)(dst_addr + 32 * dst_stride_z)) = out4.s0; |
| 2721 | *((__global float *)(dst_addr + 33 * dst_stride_z)) = out4.s1; |
| 2722 | *((__global float *)(dst_addr + 34 * dst_stride_z)) = out4.s2; |
| 2723 | *((__global float *)(dst_addr + 35 * dst_stride_z)) = out4.s3; |
| 2724 | *((__global float *)(dst_addr + 36 * dst_stride_z)) = out4.s4; |
| 2725 | *((__global float *)(dst_addr + 37 * dst_stride_z)) = out4.s5; |
| 2726 | *((__global float *)(dst_addr + 38 * dst_stride_z)) = out4.s6; |
| 2727 | *((__global float *)(dst_addr + 39 * dst_stride_z)) = out4.s7; |
| 2728 | *((__global float *)(dst_addr + 40 * dst_stride_z)) = out5.s0; |
| 2729 | *((__global float *)(dst_addr + 41 * dst_stride_z)) = out5.s1; |
| 2730 | *((__global float *)(dst_addr + 42 * dst_stride_z)) = out5.s2; |
| 2731 | *((__global float *)(dst_addr + 43 * dst_stride_z)) = out5.s3; |
| 2732 | *((__global float *)(dst_addr + 44 * dst_stride_z)) = out5.s4; |
| 2733 | *((__global float *)(dst_addr + 45 * dst_stride_z)) = out5.s5; |
| 2734 | *((__global float *)(dst_addr + 46 * dst_stride_z)) = out5.s6; |
| 2735 | *((__global float *)(dst_addr + 47 * dst_stride_z)) = out5.s7; |
| 2736 | *((__global float *)(dst_addr + 48 * dst_stride_z)) = out6.s0; |
| 2737 | *((__global float *)(dst_addr + 49 * dst_stride_z)) = out6.s1; |
| 2738 | *((__global float *)(dst_addr + 50 * dst_stride_z)) = out6.s2; |
| 2739 | *((__global float *)(dst_addr + 51 * dst_stride_z)) = out6.s3; |
| 2740 | *((__global float *)(dst_addr + 52 * dst_stride_z)) = out6.s4; |
| 2741 | *((__global float *)(dst_addr + 53 * dst_stride_z)) = out6.s5; |
| 2742 | *((__global float *)(dst_addr + 54 * dst_stride_z)) = out6.s6; |
| 2743 | *((__global float *)(dst_addr + 55 * dst_stride_z)) = out6.s7; |
| 2744 | *((__global float *)(dst_addr + 56 * dst_stride_z)) = out7.s0; |
| 2745 | *((__global float *)(dst_addr + 57 * dst_stride_z)) = out7.s1; |
| 2746 | *((__global float *)(dst_addr + 58 * dst_stride_z)) = out7.s2; |
| 2747 | *((__global float *)(dst_addr + 59 * dst_stride_z)) = out7.s3; |
| 2748 | *((__global float *)(dst_addr + 60 * dst_stride_z)) = out7.s4; |
| 2749 | *((__global float *)(dst_addr + 61 * dst_stride_z)) = out7.s5; |
| 2750 | *((__global float *)(dst_addr + 62 * dst_stride_z)) = out7.s6; |
| 2751 | *((__global float *)(dst_addr + 63 * dst_stride_z)) = out7.s7; |
| 2752 | } |
| 2753 | #endif // defined(SRC_DIM_1) && defined(SRC_DIM_2) |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2754 | #endif // defined(NUM_TILES_X) && defined(PAD_LEFT) && defined(PAD_TOP) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) |
Gian Marco Iodice | 7e4b239 | 2018-02-22 16:17:20 +0000 | [diff] [blame] | 2755 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2756 | #if defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) |
/** This OpenCL kernel performs Winograd output transform when the output tile is 2x2/2x1 or 1x2, the filter size 3x3/3x1 or 1x3 and the data layout is NCHW
 *
 * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16
 * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2
 * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2
 * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time
 * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time
 * @note If -DHAS_BIAS is passed at compile time, the bias value indexed by get_global_id(0) (the output channel) is added to every element of the output tile
 *
 * @param[in]  src_ptr                            Pointer to the source tensor. Supported data types: F32
 * @param[in]  src_stride_x                       Stride of the source tensor in X dimension (in bytes)
 * @param[in]  src_step_x                         src_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  src_stride_y                       Stride of the source tensor in Y dimension (in bytes)
 * @param[in]  src_step_y                         src_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  src_stride_z                       Stride of the source tensor in Z dimension (in bytes)
 * @param[in]  src_step_z                         src_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  src_offset_first_element_in_bytes  The offset of the first element in the source tensor
 * @param[out] dst_ptr                            Pointer to the destination tensor. Supported data types: same as @p src_ptr
 * @param[in]  dst_stride_x                       Stride of the destination tensor in X dimension (in bytes)
 * @param[in]  dst_step_x                         dst_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  dst_stride_y                       Stride of the destination tensor in Y dimension (in bytes)
 * @param[in]  dst_step_y                         dst_stride_y * number of elements along Y processed per workitem(in bytes)
 * @param[in]  dst_stride_z                       Stride of the destination tensor in Z dimension (in bytes)
 * @param[in]  dst_step_z                         dst_stride_z * number of elements along Z processed per workitem(in bytes)
 * @param[in]  dst_offset_first_element_in_bytes  The offset of the first element in the destination tensor
 * @param[in]  bias_ptr                           (Optional, only with -DHAS_BIAS) Pointer to the biases tensor. Supported data types: same as @p src_ptr
 * @param[in]  bias_stride_x                      (Optional, only with -DHAS_BIAS) Stride of the biases tensor in X dimension (in bytes)
 * @param[in]  bias_step_x                        (Optional, only with -DHAS_BIAS) bias_stride_x * number of elements along X processed per workitem(in bytes)
 * @param[in]  bias_offset_first_element_in_bytes (Optional, only with -DHAS_BIAS) The offset of the first element in the biases tensor
 */
__kernel void winograd_output_transform_2x2_3x3_nchw(
    TENSOR3D_DECLARATION(src),
    TENSOR3D_DECLARATION(dst)
#if defined(HAS_BIAS)
    ,
    VECTOR_DECLARATION(bias)
#endif // defined(HAS_BIAS)
)
{
    // Each work-item stores one 2x2, 2x1 or 1x2 output tile, depending on the filter size
    Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src);

    const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0);

    // The transformed tile is spread across the Z dimension, one element per channel:
    // 16 channels form the 4x4 tile (2D case), 4 channels form the 4x1 or 1x4 tile (1D cases)
    float d00 = *((__global float *)(src_addr + 0 * src_stride_z));
    float d01 = *((__global float *)(src_addr + 1 * src_stride_z));
    float d02 = *((__global float *)(src_addr + 2 * src_stride_z));
    float d03 = *((__global float *)(src_addr + 3 * src_stride_z));

#if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // Compute the 2x1 or 1x2 output tile (1D transform A^T * d)
    // out00 = d00 + d01 + d02
    // out01 = d01 - d02 - d03

    float out00 = d00 + d01 + d02;
    float out01 = d01 - d02 - d03;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    float d10 = *((__global float *)(src_addr + 4 * src_stride_z));
    float d11 = *((__global float *)(src_addr + 5 * src_stride_z));
    float d12 = *((__global float *)(src_addr + 6 * src_stride_z));
    float d13 = *((__global float *)(src_addr + 7 * src_stride_z));

    float d20 = *((__global float *)(src_addr + 8 * src_stride_z));
    float d21 = *((__global float *)(src_addr + 9 * src_stride_z));
    float d22 = *((__global float *)(src_addr + 10 * src_stride_z));
    float d23 = *((__global float *)(src_addr + 11 * src_stride_z));

    float d30 = *((__global float *)(src_addr + 12 * src_stride_z));
    float d31 = *((__global float *)(src_addr + 13 * src_stride_z));
    float d32 = *((__global float *)(src_addr + 14 * src_stride_z));
    float d33 = *((__global float *)(src_addr + 15 * src_stride_z));

    // Compute the 2x2 output tile (2D transform A^T * d * A).
    // k0..k3 are partial sums shared between the four outputs.
    float k0 = d01 + d11 + d21;
    float k1 = d02 + d12 + d22;
    float k2 = d11 - d21 - d31;
    float k3 = d12 - d22 - d32;

    // out00 = d00 + d10 + d20 + d01 + d11 + d21 + d02 + d12 + d22
    // out01 = d01 + d11 + d21 - (d02 + d12 + d22) - (d03 + d13 + d23)
    // out10 = d10 - d20 - d30 + (d11 - d21 - d31) + (d12 - d22 - d32)
    // out11 = d11 - d21 - d31 - (d12 - d22 - d32) - (d13 - d23 - d33)

    float out00 = d10;
    float out01 = -d13;
    float out10 = d10;
    float out11 = -d13;

    out00 += d00 + d20 + k0 + k1;
    out01 += k0 - k1 - (d03 + d23);
    out10 += -d20 - d30 + k2 + k3;
    out11 += k2 - k3 + d23 + d33;
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

    // Y of the work-item enumerates tiles in row-major order; X is the output channel
    const int y_in  = get_global_id(1);
    const int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W;
    const int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H;
    const int z_out = get_global_id(0);

#if defined(HAS_BIAS)
    // Add bias (one value per output channel)
    Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias);

    float b = *((__global float *)(vector_offset(&bias, z_out)));

    out00 += b;
    out01 += b;
#endif // defined(HAS_BIAS)

    // Get output address
    __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z;

    // Store the output tile
#if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // 1x2 tile: the two outputs are stacked along Y
    *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00;
    *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01;
#else  // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
    // 2x1 tile, or first row of the 2x2 tile: the two outputs are adjacent along X
    vstore2((float2)(out00, out01), 0, (__global float *)(dst_addr + 0 * dst_stride_y));
#endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)

#if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
#if defined(HAS_BIAS)
    // Add bias
    out10 += b;
    out11 += b;
#endif // defined(HAS_BIAS)

    // Second row of the 2x2 tile
    vstore2((float2)(out10, out11), 0, (__global float *)(dst_addr + 1 * dst_stride_y));
#endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL)
}
Giorgio Arena | dd03870 | 2018-04-16 11:20:11 +0100 | [diff] [blame] | 2883 | |
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 2884 | /** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NCHW |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 2885 | * |
| 2886 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2887 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4 |
| 2888 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4 |
| 2889 | * @note If this kernel is used to perform Winograd output transform 3x1, -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time |
| 2890 | * @note If this kernel is used to perform Winograd output transform 1x3, -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 2891 | * |
| 2892 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 2893 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 2894 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 2895 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 2896 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 2897 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 2898 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 2899 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 2900 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 2901 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 2902 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 2903 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 2904 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 2905 | * @param[in] dst_stride_z Stride of the destination tensor in Z dimension (in bytes) |
| 2906 | * @param[in] dst_step_z dst_stride_z * number of elements along Z processed per workitem(in bytes) |
| 2907 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 2908 | */ |
| 2909 | __kernel void winograd_output_transform_4x4_3x3_nchw( |
| 2910 | TENSOR3D_DECLARATION(src), |
| 2911 | TENSOR3D_DECLARATION(dst) |
| 2912 | #if defined(HAS_BIAS) |
| 2913 | , |
| 2914 | VECTOR_DECLARATION(bias) |
| 2915 | #endif // defined(HAS_BIAS) |
| 2916 | ) |
| 2917 | { |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2918 | // Each thread stores a 4x4/4x1 or 1x4 tile |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 2919 | Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); |
| 2920 | |
| 2921 | const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0); |
| 2922 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2923 | // Load the values across the channels to compose the 6x6 or 6x1 tile |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 2924 | float d00 = *((__global float *)(src_addr + 0 * src_stride_z)); |
| 2925 | float d01 = *((__global float *)(src_addr + 1 * src_stride_z)); |
| 2926 | float d02 = *((__global float *)(src_addr + 2 * src_stride_z)); |
| 2927 | float d03 = *((__global float *)(src_addr + 3 * src_stride_z)); |
| 2928 | float d04 = *((__global float *)(src_addr + 4 * src_stride_z)); |
| 2929 | float d05 = *((__global float *)(src_addr + 5 * src_stride_z)); |
| 2930 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 2931 | #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 2932 | // Compute out00, out01, out02 and out03 |
| 2933 | float out00 = d00 + d01 + d02 + d03 + d04; |
| 2934 | float out01 = d01 - d02 + 2.0f * d03 - 2.0f * d04; |
| 2935 | float out02 = d01 + d02 + 4.0f * d03 + 4.0f * d04; |
| 2936 | float out03 = d01 - d02 + 8.0f * d03 - 8.0f * d04 + d05; |
| 2937 | #else // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 2938 | float d10 = *((__global float *)(src_addr + 6 * src_stride_z)); |
| 2939 | float d11 = *((__global float *)(src_addr + 7 * src_stride_z)); |
| 2940 | float d12 = *((__global float *)(src_addr + 8 * src_stride_z)); |
| 2941 | float d13 = *((__global float *)(src_addr + 9 * src_stride_z)); |
| 2942 | float d14 = *((__global float *)(src_addr + 10 * src_stride_z)); |
| 2943 | float d15 = *((__global float *)(src_addr + 11 * src_stride_z)); |
| 2944 | |
| 2945 | float d20 = *((__global float *)(src_addr + 12 * src_stride_z)); |
| 2946 | float d21 = *((__global float *)(src_addr + 13 * src_stride_z)); |
| 2947 | float d22 = *((__global float *)(src_addr + 14 * src_stride_z)); |
| 2948 | float d23 = *((__global float *)(src_addr + 15 * src_stride_z)); |
| 2949 | float d24 = *((__global float *)(src_addr + 16 * src_stride_z)); |
| 2950 | float d25 = *((__global float *)(src_addr + 17 * src_stride_z)); |
| 2951 | |
| 2952 | float d30 = *((__global float *)(src_addr + 18 * src_stride_z)); |
| 2953 | float d31 = *((__global float *)(src_addr + 19 * src_stride_z)); |
| 2954 | float d32 = *((__global float *)(src_addr + 20 * src_stride_z)); |
| 2955 | float d33 = *((__global float *)(src_addr + 21 * src_stride_z)); |
| 2956 | float d34 = *((__global float *)(src_addr + 22 * src_stride_z)); |
| 2957 | float d35 = *((__global float *)(src_addr + 23 * src_stride_z)); |
| 2958 | |
| 2959 | float d40 = *((__global float *)(src_addr + 24 * src_stride_z)); |
| 2960 | float d41 = *((__global float *)(src_addr + 25 * src_stride_z)); |
| 2961 | float d42 = *((__global float *)(src_addr + 26 * src_stride_z)); |
| 2962 | float d43 = *((__global float *)(src_addr + 27 * src_stride_z)); |
| 2963 | float d44 = *((__global float *)(src_addr + 28 * src_stride_z)); |
| 2964 | float d45 = *((__global float *)(src_addr + 29 * src_stride_z)); |
| 2965 | |
| 2966 | float d50 = *((__global float *)(src_addr + 30 * src_stride_z)); |
| 2967 | float d51 = *((__global float *)(src_addr + 31 * src_stride_z)); |
| 2968 | float d52 = *((__global float *)(src_addr + 32 * src_stride_z)); |
| 2969 | float d53 = *((__global float *)(src_addr + 33 * src_stride_z)); |
| 2970 | float d54 = *((__global float *)(src_addr + 34 * src_stride_z)); |
| 2971 | float d55 = *((__global float *)(src_addr + 35 * src_stride_z)); |
| 2972 | |
| 2973 | // Compute out00, out01, out02 and out03 |
| 2974 | float out00 = d01 + d21 + d41 + d11 + d31; |
| 2975 | float out01 = d01 + d21 + d41 + d11 + d31; |
| 2976 | float out02 = d01 + d21 + d41 + d11 + d31; |
| 2977 | float out03 = d01 + d21 + d41 + d11 + d31; |
| 2978 | |
| 2979 | float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44; |
| 2980 | float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44; |
| 2981 | |
| 2982 | out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42; |
| 2983 | out01 += k1 - d02 - d12 - d22 - d32 - d42; |
| 2984 | out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42; |
| 2985 | out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45; |
| 2986 | |
| 2987 | // Compute out10, out11, out12 and out13 |
| 2988 | float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 2989 | float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 2990 | float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 2991 | float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 2992 | |
| 2993 | k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44; |
| 2994 | k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44; |
| 2995 | |
| 2996 | out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42; |
| 2997 | out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42; |
| 2998 | out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42; |
| 2999 | out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45; |
| 3000 | |
| 3001 | // Compute out20, out21, out22 and out23 |
| 3002 | float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3003 | float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3004 | float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3005 | float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3006 | |
| 3007 | k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44; |
| 3008 | k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44; |
| 3009 | |
| 3010 | out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42; |
| 3011 | out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42; |
| 3012 | out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42; |
| 3013 | out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45; |
| 3014 | |
| 3015 | // Compute out30, out31, out32 and out33 |
| 3016 | float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3017 | float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3018 | float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3019 | float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3020 | |
| 3021 | k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54; |
| 3022 | k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54; |
| 3023 | |
| 3024 | out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52; |
| 3025 | out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52; |
| 3026 | out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52; |
| 3027 | out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55; |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3028 | #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) || defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3029 | |
| 3030 | int y_in = get_global_id(1); |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3031 | int x_out = (y_in % NUM_TILES_X) * OUTPUT_TILE_W; |
| 3032 | int y_out = (y_in / NUM_TILES_X) * OUTPUT_TILE_H; |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3033 | int z_out = get_global_id(0); |
| 3034 | |
| 3035 | #if defined(HAS_BIAS) |
| 3036 | // Add bias |
| 3037 | Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias); |
| 3038 | |
| 3039 | float b = (float) * ((__global float *)(vector_offset(&bias, z_out))); |
| 3040 | |
| 3041 | out00 += (float)b; |
| 3042 | out01 += (float)b; |
| 3043 | out02 += (float)b; |
| 3044 | out03 += (float)b; |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3045 | #endif // defined(HAS_BIAS) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3046 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3047 | // Get output address |
| 3048 | __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z; |
| 3049 | |
| 3050 | // Store the output tile |
| 3051 | #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3052 | *((__global float *)(dst_addr + 0 * dst_stride_y)) = out00; |
| 3053 | *((__global float *)(dst_addr + 1 * dst_stride_y)) = out01; |
| 3054 | *((__global float *)(dst_addr + 2 * dst_stride_y)) = out02; |
| 3055 | *((__global float *)(dst_addr + 3 * dst_stride_y)) = out03; |
| 3056 | #else // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3057 | vstore4((float4)(out00, out01, out02, out03), 0, (__global float *)(dst_addr + 0 * dst_stride_y)); |
| 3058 | #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3059 | |
| 3060 | #if !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3061 | #if defined(HAS_BIAS) |
| 3062 | // Add bias |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3063 | out10 += (float)b; |
| 3064 | out11 += (float)b; |
| 3065 | out12 += (float)b; |
| 3066 | out13 += (float)b; |
| 3067 | |
| 3068 | out20 += (float)b; |
| 3069 | out21 += (float)b; |
| 3070 | out22 += (float)b; |
| 3071 | out23 += (float)b; |
| 3072 | |
| 3073 | out30 += (float)b; |
| 3074 | out31 += (float)b; |
| 3075 | out32 += (float)b; |
| 3076 | out33 += (float)b; |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3077 | #endif // defined(HAS_BIAS) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3078 | vstore4((float4)(out10, out11, out12, out13), 0, (__global float *)(dst_addr + 1 * dst_stride_y)); |
| 3079 | vstore4((float4)(out20, out21, out22, out23), 0, (__global float *)(dst_addr + 2 * dst_stride_y)); |
| 3080 | vstore4((float4)(out30, out31, out32, out33), 0, (__global float *)(dst_addr + 3 * dst_stride_y)); |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3081 | #endif // !defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) && !defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
Gian Marco Iodice | e52a300 | 2018-04-11 15:59:10 +0100 | [diff] [blame] | 3082 | } |
| 3083 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3084 | #if defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) |
| 3085 | /** This OpenCL kernel performs Winograd output transform when the output tile is 2x1, the filter size 3x1 and the data layout is NCHW |
| 3086 | * |
| 3087 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3088 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=2 |
| 3089 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1 |
| 3090 | * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time |
| 3091 | * |
| 3092 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3093 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3094 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3095 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3096 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3097 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3098 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3099 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3100 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3101 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3102 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3103 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3104 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3105 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3106 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3107 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3108 | */ |
| 3109 | __kernel void winograd_output_transform_2x1_3x1_nchw( |
| 3110 | TENSOR3D_DECLARATION(src), |
| 3111 | TENSOR3D_DECLARATION(dst) |
| 3112 | #if defined(HAS_BIAS) |
| 3113 | , |
| 3114 | VECTOR_DECLARATION(bias) |
| 3115 | #endif // defined(HAS_BIAS) |
| 3116 | ) |
| 3117 | { |
| 3118 | winograd_output_transform_2x2_3x3_nchw(src_ptr, |
| 3119 | src_stride_x, |
| 3120 | src_step_x, |
| 3121 | src_stride_y, |
| 3122 | src_step_y, |
| 3123 | src_stride_z, |
| 3124 | src_step_z, |
| 3125 | src_offset_first_element_in_bytes, |
| 3126 | dst_ptr, |
| 3127 | dst_stride_x, |
| 3128 | dst_step_x, |
| 3129 | dst_stride_y, |
| 3130 | dst_step_y, |
| 3131 | dst_stride_z, |
| 3132 | dst_step_z, |
| 3133 | dst_offset_first_element_in_bytes |
| 3134 | #if defined(HAS_BIAS) |
| 3135 | , |
| 3136 | bias_ptr, |
| 3137 | bias_stride_x, |
| 3138 | bias_step_x, |
| 3139 | bias_offset_first_element_in_bytes |
| 3140 | #endif // defined(HAS_BIAS) |
| 3141 | ); |
| 3142 | } |
| 3143 | |
| 3144 | /** This OpenCL kernel performs Winograd output transform when the output tile is 4x1, the filter size 3x1 and the data layout is NCHW |
| 3145 | * |
| 3146 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3147 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=4 |
| 3148 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=1 |
| 3149 | * @note -DWINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL has to be passed at compile time |
| 3150 | * |
| 3151 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3152 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3153 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3154 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3155 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3156 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3157 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3158 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3159 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3160 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3161 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3162 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3163 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3164 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3165 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3166 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3167 | */ |
| 3168 | __kernel void winograd_output_transform_4x1_3x1_nchw( |
| 3169 | TENSOR3D_DECLARATION(src), |
| 3170 | TENSOR3D_DECLARATION(dst) |
| 3171 | #if defined(HAS_BIAS) |
| 3172 | , |
| 3173 | VECTOR_DECLARATION(bias) |
| 3174 | #endif // defined(HAS_BIAS) |
| 3175 | ) |
| 3176 | { |
| 3177 | winograd_output_transform_4x4_3x3_nchw(src_ptr, |
| 3178 | src_stride_x, |
| 3179 | src_step_x, |
| 3180 | src_stride_y, |
| 3181 | src_step_y, |
| 3182 | src_stride_z, |
| 3183 | src_step_z, |
| 3184 | src_offset_first_element_in_bytes, |
| 3185 | dst_ptr, |
| 3186 | dst_stride_x, |
| 3187 | dst_step_x, |
| 3188 | dst_stride_y, |
| 3189 | dst_step_y, |
| 3190 | dst_stride_z, |
| 3191 | dst_step_z, |
| 3192 | dst_offset_first_element_in_bytes |
| 3193 | #if defined(HAS_BIAS) |
| 3194 | , |
| 3195 | bias_ptr, |
| 3196 | bias_stride_x, |
| 3197 | bias_step_x, |
| 3198 | bias_offset_first_element_in_bytes |
| 3199 | #endif // defined(HAS_BIAS) |
| 3200 | ); |
| 3201 | } |
| 3202 | #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_HORIZONTAL) |
| 3203 | |
| 3204 | #if defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3205 | /** This OpenCL kernel performs Winograd output transform when the output tile is 1x2, the filter size 1x3 and the data layout is NCHW |
| 3206 | * |
| 3207 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3208 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1 |
| 3209 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=2 |
| 3210 | * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time |
| 3211 | * |
| 3212 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3213 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3214 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3215 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3216 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3217 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3218 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3219 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3220 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3221 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3222 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3223 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3224 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3225 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3226 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3227 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3228 | */ |
| 3229 | __kernel void winograd_output_transform_1x2_1x3_nchw( |
| 3230 | TENSOR3D_DECLARATION(src), |
| 3231 | TENSOR3D_DECLARATION(dst) |
| 3232 | #if defined(HAS_BIAS) |
| 3233 | , |
| 3234 | VECTOR_DECLARATION(bias) |
| 3235 | #endif // defined(HAS_BIAS) |
| 3236 | ) |
| 3237 | { |
| 3238 | winograd_output_transform_2x2_3x3_nchw(src_ptr, |
| 3239 | src_stride_x, |
| 3240 | src_step_x, |
| 3241 | src_stride_y, |
| 3242 | src_step_y, |
| 3243 | src_stride_z, |
| 3244 | src_step_z, |
| 3245 | src_offset_first_element_in_bytes, |
| 3246 | dst_ptr, |
| 3247 | dst_stride_x, |
| 3248 | dst_step_x, |
| 3249 | dst_stride_y, |
| 3250 | dst_step_y, |
| 3251 | dst_stride_z, |
| 3252 | dst_step_z, |
| 3253 | dst_offset_first_element_in_bytes |
| 3254 | #if defined(HAS_BIAS) |
| 3255 | , |
| 3256 | bias_ptr, |
| 3257 | bias_stride_x, |
| 3258 | bias_step_x, |
| 3259 | bias_offset_first_element_in_bytes |
| 3260 | #endif // defined(HAS_BIAS) |
| 3261 | ); |
| 3262 | } |
| 3263 | |
| 3264 | /** This OpenCL kernel performs Winograd output transform when the output tile is 1x4, the filter size 1x3 and the data layout is NCHW |
| 3265 | * |
| 3266 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3267 | * @note The width of the output tile must be passed at compile time using -DOUTPUT_TILE_W: e.g. -DOUTPUT_TILE_W=1 |
| 3268 | * @note The height of the output tile must be passed at compile time using -DOUTPUT_TILE_H: e.g. -DOUTPUT_TILE_H=4 |
| 3269 | * @note -DWINOGRAD_OUTPUT_TRANSFORM_VERTICAL has to be passed at compile time |
| 3270 | * |
| 3271 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3272 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3273 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3274 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3275 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3276 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3277 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3278 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3279 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3280 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3281 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3282 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3283 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3284 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3285 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3286 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3287 | */ |
| 3288 | __kernel void winograd_output_transform_1x4_1x3_nchw( |
| 3289 | TENSOR3D_DECLARATION(src), |
| 3290 | TENSOR3D_DECLARATION(dst) |
| 3291 | #if defined(HAS_BIAS) |
| 3292 | , |
| 3293 | VECTOR_DECLARATION(bias) |
| 3294 | #endif // defined(HAS_BIAS) |
| 3295 | ) |
| 3296 | { |
| 3297 | winograd_output_transform_4x4_3x3_nchw(src_ptr, |
| 3298 | src_stride_x, |
| 3299 | src_step_x, |
| 3300 | src_stride_y, |
| 3301 | src_step_y, |
| 3302 | src_stride_z, |
| 3303 | src_step_z, |
| 3304 | src_offset_first_element_in_bytes, |
| 3305 | dst_ptr, |
| 3306 | dst_stride_x, |
| 3307 | dst_step_x, |
| 3308 | dst_stride_y, |
| 3309 | dst_step_y, |
| 3310 | dst_stride_z, |
| 3311 | dst_step_z, |
| 3312 | dst_offset_first_element_in_bytes |
| 3313 | #if defined(HAS_BIAS) |
| 3314 | , |
| 3315 | bias_ptr, |
| 3316 | bias_stride_x, |
| 3317 | bias_step_x, |
| 3318 | bias_offset_first_element_in_bytes |
| 3319 | #endif // defined(HAS_BIAS) |
| 3320 | ); |
| 3321 | } |
| 3322 | #endif // defined(WINOGRAD_OUTPUT_TRANSFORM_VERTICAL) |
| 3323 | |
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 3324 | /** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 3x3 and the data layout is NHWC |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3325 | * |
| 3326 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3327 | * |
| 3328 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3329 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3330 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3331 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3332 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3333 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3334 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3335 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3336 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3337 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3338 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3339 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3340 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3341 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3342 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3343 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3344 | * @param[in] dst_size Size of the destination tensor, minus the last padding |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3345 | */ |
| 3346 | __kernel void winograd_output_transform_4x4_3x3_nhwc( |
| 3347 | TENSOR3D_DECLARATION(src), |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3348 | TENSOR3D_DECLARATION(dst), |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3349 | #if defined(HAS_BIAS) |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3350 | VECTOR_DECLARATION(bias), |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3351 | #endif // defined(HAS_BIAS) |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3352 | int dst_size) |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3353 | { |
| 3354 | // Each thread stores a 4x4 tile |
| 3355 | Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); |
| 3356 | |
| 3357 | const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0); |
| 3358 | |
| 3359 | // Load the values across the 36 channels to compose the 6x6 tile |
| 3360 | float d00 = *((__global float *)(src_addr + 0 * src_stride_z)); |
| 3361 | float d01 = *((__global float *)(src_addr + 1 * src_stride_z)); |
| 3362 | float d02 = *((__global float *)(src_addr + 2 * src_stride_z)); |
| 3363 | float d03 = *((__global float *)(src_addr + 3 * src_stride_z)); |
| 3364 | float d04 = *((__global float *)(src_addr + 4 * src_stride_z)); |
| 3365 | float d05 = *((__global float *)(src_addr + 5 * src_stride_z)); |
| 3366 | |
| 3367 | float d10 = *((__global float *)(src_addr + 6 * src_stride_z)); |
| 3368 | float d11 = *((__global float *)(src_addr + 7 * src_stride_z)); |
| 3369 | float d12 = *((__global float *)(src_addr + 8 * src_stride_z)); |
| 3370 | float d13 = *((__global float *)(src_addr + 9 * src_stride_z)); |
| 3371 | float d14 = *((__global float *)(src_addr + 10 * src_stride_z)); |
| 3372 | float d15 = *((__global float *)(src_addr + 11 * src_stride_z)); |
| 3373 | |
| 3374 | float d20 = *((__global float *)(src_addr + 12 * src_stride_z)); |
| 3375 | float d21 = *((__global float *)(src_addr + 13 * src_stride_z)); |
| 3376 | float d22 = *((__global float *)(src_addr + 14 * src_stride_z)); |
| 3377 | float d23 = *((__global float *)(src_addr + 15 * src_stride_z)); |
| 3378 | float d24 = *((__global float *)(src_addr + 16 * src_stride_z)); |
| 3379 | float d25 = *((__global float *)(src_addr + 17 * src_stride_z)); |
| 3380 | |
| 3381 | float d30 = *((__global float *)(src_addr + 18 * src_stride_z)); |
| 3382 | float d31 = *((__global float *)(src_addr + 19 * src_stride_z)); |
| 3383 | float d32 = *((__global float *)(src_addr + 20 * src_stride_z)); |
| 3384 | float d33 = *((__global float *)(src_addr + 21 * src_stride_z)); |
| 3385 | float d34 = *((__global float *)(src_addr + 22 * src_stride_z)); |
| 3386 | float d35 = *((__global float *)(src_addr + 23 * src_stride_z)); |
| 3387 | |
| 3388 | float d40 = *((__global float *)(src_addr + 24 * src_stride_z)); |
| 3389 | float d41 = *((__global float *)(src_addr + 25 * src_stride_z)); |
| 3390 | float d42 = *((__global float *)(src_addr + 26 * src_stride_z)); |
| 3391 | float d43 = *((__global float *)(src_addr + 27 * src_stride_z)); |
| 3392 | float d44 = *((__global float *)(src_addr + 28 * src_stride_z)); |
| 3393 | float d45 = *((__global float *)(src_addr + 29 * src_stride_z)); |
| 3394 | |
| 3395 | float d50 = *((__global float *)(src_addr + 30 * src_stride_z)); |
| 3396 | float d51 = *((__global float *)(src_addr + 31 * src_stride_z)); |
| 3397 | float d52 = *((__global float *)(src_addr + 32 * src_stride_z)); |
| 3398 | float d53 = *((__global float *)(src_addr + 33 * src_stride_z)); |
| 3399 | float d54 = *((__global float *)(src_addr + 34 * src_stride_z)); |
| 3400 | float d55 = *((__global float *)(src_addr + 35 * src_stride_z)); |
| 3401 | |
| 3402 | // Compute out00, out01, out02 and out03 |
| 3403 | float out00 = d01 + d21 + d41 + d11 + d31; |
| 3404 | float out01 = d01 + d21 + d41 + d11 + d31; |
| 3405 | float out02 = d01 + d21 + d41 + d11 + d31; |
| 3406 | float out03 = d01 + d21 + d41 + d11 + d31; |
| 3407 | |
| 3408 | float k0 = d03 + d04 + d13 + d14 + d23 + d24 + d33 + d34 + d43 + d44; |
| 3409 | float k1 = 2.0f * d03 - 2.0f * d04 + 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 2.0f * d33 - 2.0f * d34 + 2.0f * d43 - 2.0f * d44; |
| 3410 | |
| 3411 | out00 += k0 + d00 + d02 + d10 + d12 + d20 + d22 + d30 + d32 + d40 + d42; |
| 3412 | out01 += k1 - d02 - d12 - d22 - d32 - d42; |
| 3413 | out02 += 4.0f * k0 + d02 + d12 + d22 + d32 + d42; |
| 3414 | out03 += 4.0f * k1 - d02 - d12 - d22 - d32 - d42 + d05 + d15 + d25 + d35 + d45; |
| 3415 | |
| 3416 | // Compute out10, out11, out12 and out13 |
| 3417 | float out10 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 3418 | float out11 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 3419 | float out12 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 3420 | float out13 = d11 - d21 + 2.0f * d31 - 2.0f * d41; |
| 3421 | |
| 3422 | k0 = d13 + d14 - d23 - d24 + 2.0f * d33 + 2.0f * d34 - 2.0f * d43 - 2.0f * d44; |
| 3423 | k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 4.0f * d33 - 4.0f * d34 - 4.0f * d43 + 4.0f * d44; |
| 3424 | |
| 3425 | out10 += k0 + d10 + d12 - d20 - d22 + 2.0f * d30 + 2.0f * d32 - 2.0f * d40 - 2.0f * d42; |
| 3426 | out11 += k1 - d12 + d22 - 2.0f * d32 + 2.0f * d42; |
| 3427 | out12 += 4.0f * k0 + d12 - d22 + 2.0f * d32 - 2.0f * d42; |
| 3428 | out13 += 4.0f * k1 - d12 + d15 + d22 - d25 - 2.0f * d32 + 2.0f * d35 + 2.0f * d42 - 2.0f * d45; |
| 3429 | |
| 3430 | // Compute out20, out21, out22 and out23 |
| 3431 | float out20 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3432 | float out21 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3433 | float out22 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3434 | float out23 = d11 + d21 + 4.0f * d31 + 4.0f * d41; |
| 3435 | |
| 3436 | k0 = d13 + d14 + d23 + d24 + 4.0f * d33 + 4.0f * d34 + 4.0f * d43 + 4.0f * d44; |
| 3437 | k1 = 2.0f * d13 - 2.0f * d14 + 2.0f * d23 - 2.0f * d24 + 8.0f * d33 - 8.0f * d34 + 8.0f * d43 - 8.0f * d44; |
| 3438 | |
| 3439 | out20 += k0 + d10 + d12 + d20 + d22 + 4.0f * d30 + 4.0f * d32 + 4.0f * d40 + 4.0f * d42; |
| 3440 | out21 += k1 - d12 - d22 - 4.0f * d32 - 4.0f * d42; |
| 3441 | out22 += 4.0f * k0 + d12 + d22 + 4.0f * d32 + 4.0f * d42; |
| 3442 | out23 += 4.0f * k1 - d12 + d15 - d22 + d25 - 4.0f * d32 + 4.0f * d35 - 4.0f * d42 + 4.0f * d45; |
| 3443 | |
| 3444 | // Compute out30, out31, out32 and out33 |
| 3445 | float out30 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3446 | float out31 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3447 | float out32 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3448 | float out33 = d11 - d21 + 8.0f * d31 - 8.0f * d41 + d51; |
| 3449 | |
| 3450 | k0 = d13 + d14 - d23 - d24 + 8.0f * d33 + 8.0f * d34 - 8.0f * d43 - 8.0f * d44 + d53 + d54; |
| 3451 | k1 = 2.0f * d13 - 2.0f * d14 - 2.0f * d23 + 2.0f * d24 + 16.0f * d33 - 16.0f * d34 - 16.0f * d43 + 16.0f * d44 + 2.0f * d53 - 2.0f * d54; |
| 3452 | |
| 3453 | out30 += k0 + d10 + d12 - d20 - d22 + 8.0f * d30 + 8.0f * d32 - 8.0f * d40 - 8.0f * d42 + d50 + d52; |
| 3454 | out31 += k1 - d12 + d22 - 8.0f * d32 + 8.0f * d42 - d52; |
| 3455 | out32 += 4.0f * k0 + d12 - d22 + 8.0f * d32 - 8.0f * d42 + d52; |
| 3456 | out33 += 4.0f * k1 - d12 + d15 + d22 - d25 - 8.0f * d32 + 8.0f * d35 + 8.0f * d42 - 8.0f * d45 - d52 + d55; |
| 3457 | |
| 3458 | int y_in = get_global_id(1); |
| 3459 | int x_out = get_global_id(0); |
| 3460 | int y_out = (y_in % NUM_TILES_X) * 4; |
| 3461 | int z_out = (y_in / NUM_TILES_X) * 4; |
| 3462 | |
| 3463 | #if defined(HAS_BIAS) |
| 3464 | // Add bias |
| 3465 | Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias); |
| 3466 | |
| 3467 | float b = (float) * ((__global float *)(vector_offset(&bias, z_out))); |
| 3468 | |
| 3469 | out00 += (float)b; |
| 3470 | out01 += (float)b; |
| 3471 | out02 += (float)b; |
| 3472 | out03 += (float)b; |
| 3473 | |
| 3474 | out10 += (float)b; |
| 3475 | out11 += (float)b; |
| 3476 | out12 += (float)b; |
| 3477 | out13 += (float)b; |
| 3478 | |
| 3479 | out20 += (float)b; |
| 3480 | out21 += (float)b; |
| 3481 | out22 += (float)b; |
| 3482 | out23 += (float)b; |
| 3483 | |
| 3484 | out30 += (float)b; |
| 3485 | out31 += (float)b; |
| 3486 | out32 += (float)b; |
| 3487 | out33 += (float)b; |
| 3488 | |
| 3489 | #endif // defined(HAS_BIAS) |
| 3490 | |
| 3491 | // Get output address |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3492 | int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z); |
| 3493 | offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding). |
| 3494 | int4 mult_y = min(dst_size - offset, 1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise. |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3495 | |
| 3496 | // Store the 4x4 output tile |
Georgios Pinitas | c084f0d | 2018-06-11 17:43:31 +0100 | [diff] [blame] | 3497 | *((__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0)) = out00; |
| 3498 | *((__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0)) = out01; |
| 3499 | *((__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0)) = out02; |
| 3500 | *((__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0)) = out03; |
| 3501 | *((__global float *)(dst_ptr + mult_y.s1 * 0 * dst_stride_y + offset.s1)) = out10; |
| 3502 | *((__global float *)(dst_ptr + mult_y.s1 * 1 * dst_stride_y + offset.s1)) = out11; |
| 3503 | *((__global float *)(dst_ptr + mult_y.s1 * 2 * dst_stride_y + offset.s1)) = out12; |
| 3504 | *((__global float *)(dst_ptr + mult_y.s1 * 3 * dst_stride_y + offset.s1)) = out13; |
| 3505 | *((__global float *)(dst_ptr + mult_y.s2 * 0 * dst_stride_y + offset.s2)) = out20; |
| 3506 | *((__global float *)(dst_ptr + mult_y.s2 * 1 * dst_stride_y + offset.s2)) = out21; |
| 3507 | *((__global float *)(dst_ptr + mult_y.s2 * 2 * dst_stride_y + offset.s2)) = out22; |
| 3508 | *((__global float *)(dst_ptr + mult_y.s2 * 3 * dst_stride_y + offset.s2)) = out23; |
| 3509 | *((__global float *)(dst_ptr + mult_y.s3 * 0 * dst_stride_y + offset.s3)) = out30; |
| 3510 | *((__global float *)(dst_ptr + mult_y.s3 * 1 * dst_stride_y + offset.s3)) = out31; |
| 3511 | *((__global float *)(dst_ptr + mult_y.s3 * 2 * dst_stride_y + offset.s3)) = out32; |
| 3512 | *((__global float *)(dst_ptr + mult_y.s3 * 3 * dst_stride_y + offset.s3)) = out33; |
Giorgio Arena | 3695f9a | 2018-04-23 17:41:22 +0100 | [diff] [blame] | 3513 | } |
| 3514 | |
/** Apply the 8-input / 4-output column step of a Winograd output transform.
 *
 * Writes:
 *   col.s0 = (d1 + d2) +     (d3 + d4) + 8 * (d5 + d6) + d0
 *   col.s1 = (d1 - d2) + 2 * (d3 - d4) + 4 * (d5 - d6)
 *   col.s2 = (d1 + d2) + 4 * (d3 + d4) + 2 * (d5 + d6)
 *   col.s3 = (d1 - d2) + 8 * (d3 - d4) +     (d5 - d6) + d7
 *
 * @param col       Output vector lvalue; its components .s0 to .s3 are written
 * @param d0        First input value of the column
 * @param d1        Input value; evaluated twice, so it must be free of side effects (same for d2 to d6)
 * @param d2        Input value
 * @param d3        Input value
 * @param d4        Input value
 * @param d5        Input value
 * @param d6        Input value
 * @param d7        Last input value of the column
 * @param comm_fact Scratch vector lvalue; its components .s0 to .s2 are overwritten
 *
 * Every argument is parenthesized in the expansion so that compound expressions (e.g. a + b) can be passed safely.
 */
#define COMPUTE_TMP_COL(col, d0, d1, d2, d3, d4, d5, d6, d7, comm_fact)             \
    ({                                                                              \
        (comm_fact).s0 = (d1) + (d2);                                               \
        (comm_fact).s1 = (d3) + (d4);                                               \
        (comm_fact).s2 = (d5) + (d6);                                               \
                                                                                    \
        (col).s0 = (comm_fact).s0 + (comm_fact).s1 + 8.f * (comm_fact).s2 + (d0);   \
        (col).s2 = (comm_fact).s0 + 4.f * (comm_fact).s1 + 2.f * (comm_fact).s2;    \
                                                                                    \
        (comm_fact).s0 = (d1) - (d2);                                               \
        (comm_fact).s1 = (d3) - (d4);                                               \
        (comm_fact).s2 = (d5) - (d6);                                               \
                                                                                    \
        (col).s1 = (comm_fact).s0 + 2.f * (comm_fact).s1 + 4.f * (comm_fact).s2;    \
        (col).s3 = (comm_fact).s0 + 8.f * (comm_fact).s1 + (comm_fact).s2 + (d7);   \
    })
| 3531 | |
Giorgio Arena | c42f28d | 2018-04-26 11:33:05 +0100 | [diff] [blame] | 3532 | /** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data layout is NCHW |
Giorgio Arena | dd03870 | 2018-04-16 11:20:11 +0100 | [diff] [blame] | 3533 | * |
| 3534 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3535 | * |
| 3536 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3537 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3538 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3539 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3540 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3541 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3542 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3543 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3544 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3545 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3546 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3547 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3548 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3549 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3550 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3551 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3552 | */ |
| 3553 | __kernel void winograd_output_transform_4x4_5x5_nchw( |
| 3554 | TENSOR3D_DECLARATION(src), |
| 3555 | TENSOR3D_DECLARATION(dst) |
| 3556 | #if defined(HAS_BIAS) |
| 3557 | , |
| 3558 | VECTOR_DECLARATION(bias) |
| 3559 | #endif // defined(HAS_BIAS) |
| 3560 | ) |
| 3561 | { |
| 3562 | // Each thread stores a 4x4 tile |
| 3563 | Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); |
| 3564 | |
| 3565 | const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0); |
| 3566 | |
| 3567 | // Load the values across the 64 channels to compose the 8x8 input tile |
| 3568 | float d00 = *((__global float *)(src_addr + 0 * src_stride_z)); |
| 3569 | float d01 = *((__global float *)(src_addr + 1 * src_stride_z)); |
| 3570 | float d02 = *((__global float *)(src_addr + 2 * src_stride_z)); |
| 3571 | float d03 = *((__global float *)(src_addr + 3 * src_stride_z)); |
| 3572 | float d04 = *((__global float *)(src_addr + 4 * src_stride_z)); |
| 3573 | float d05 = *((__global float *)(src_addr + 5 * src_stride_z)); |
| 3574 | float d06 = *((__global float *)(src_addr + 6 * src_stride_z)); |
| 3575 | float d07 = *((__global float *)(src_addr + 7 * src_stride_z)); |
| 3576 | |
| 3577 | float d10 = *((__global float *)(src_addr + 8 * src_stride_z)); |
| 3578 | float d11 = *((__global float *)(src_addr + 9 * src_stride_z)); |
| 3579 | float d12 = *((__global float *)(src_addr + 10 * src_stride_z)); |
| 3580 | float d13 = *((__global float *)(src_addr + 11 * src_stride_z)); |
| 3581 | float d14 = *((__global float *)(src_addr + 12 * src_stride_z)); |
| 3582 | float d15 = *((__global float *)(src_addr + 13 * src_stride_z)); |
| 3583 | float d16 = *((__global float *)(src_addr + 14 * src_stride_z)); |
| 3584 | float d17 = *((__global float *)(src_addr + 15 * src_stride_z)); |
| 3585 | |
| 3586 | float d20 = *((__global float *)(src_addr + 16 * src_stride_z)); |
| 3587 | float d21 = *((__global float *)(src_addr + 17 * src_stride_z)); |
| 3588 | float d22 = *((__global float *)(src_addr + 18 * src_stride_z)); |
| 3589 | float d23 = *((__global float *)(src_addr + 19 * src_stride_z)); |
| 3590 | float d24 = *((__global float *)(src_addr + 20 * src_stride_z)); |
| 3591 | float d25 = *((__global float *)(src_addr + 21 * src_stride_z)); |
| 3592 | float d26 = *((__global float *)(src_addr + 22 * src_stride_z)); |
| 3593 | float d27 = *((__global float *)(src_addr + 23 * src_stride_z)); |
| 3594 | |
| 3595 | float d30 = *((__global float *)(src_addr + 24 * src_stride_z)); |
| 3596 | float d31 = *((__global float *)(src_addr + 25 * src_stride_z)); |
| 3597 | float d32 = *((__global float *)(src_addr + 26 * src_stride_z)); |
| 3598 | float d33 = *((__global float *)(src_addr + 27 * src_stride_z)); |
| 3599 | float d34 = *((__global float *)(src_addr + 28 * src_stride_z)); |
| 3600 | float d35 = *((__global float *)(src_addr + 29 * src_stride_z)); |
| 3601 | float d36 = *((__global float *)(src_addr + 30 * src_stride_z)); |
| 3602 | float d37 = *((__global float *)(src_addr + 31 * src_stride_z)); |
| 3603 | |
| 3604 | float d40 = *((__global float *)(src_addr + 32 * src_stride_z)); |
| 3605 | float d41 = *((__global float *)(src_addr + 33 * src_stride_z)); |
| 3606 | float d42 = *((__global float *)(src_addr + 34 * src_stride_z)); |
| 3607 | float d43 = *((__global float *)(src_addr + 35 * src_stride_z)); |
| 3608 | float d44 = *((__global float *)(src_addr + 36 * src_stride_z)); |
| 3609 | float d45 = *((__global float *)(src_addr + 37 * src_stride_z)); |
| 3610 | float d46 = *((__global float *)(src_addr + 38 * src_stride_z)); |
| 3611 | float d47 = *((__global float *)(src_addr + 39 * src_stride_z)); |
| 3612 | |
| 3613 | float d50 = *((__global float *)(src_addr + 40 * src_stride_z)); |
| 3614 | float d51 = *((__global float *)(src_addr + 41 * src_stride_z)); |
| 3615 | float d52 = *((__global float *)(src_addr + 42 * src_stride_z)); |
| 3616 | float d53 = *((__global float *)(src_addr + 43 * src_stride_z)); |
| 3617 | float d54 = *((__global float *)(src_addr + 44 * src_stride_z)); |
| 3618 | float d55 = *((__global float *)(src_addr + 45 * src_stride_z)); |
| 3619 | float d56 = *((__global float *)(src_addr + 46 * src_stride_z)); |
| 3620 | float d57 = *((__global float *)(src_addr + 47 * src_stride_z)); |
| 3621 | |
| 3622 | float d60 = *((__global float *)(src_addr + 48 * src_stride_z)); |
| 3623 | float d61 = *((__global float *)(src_addr + 49 * src_stride_z)); |
| 3624 | float d62 = *((__global float *)(src_addr + 50 * src_stride_z)); |
| 3625 | float d63 = *((__global float *)(src_addr + 51 * src_stride_z)); |
| 3626 | float d64 = *((__global float *)(src_addr + 52 * src_stride_z)); |
| 3627 | float d65 = *((__global float *)(src_addr + 53 * src_stride_z)); |
| 3628 | float d66 = *((__global float *)(src_addr + 54 * src_stride_z)); |
| 3629 | float d67 = *((__global float *)(src_addr + 55 * src_stride_z)); |
| 3630 | |
| 3631 | float d70 = *((__global float *)(src_addr + 56 * src_stride_z)); |
| 3632 | float d71 = *((__global float *)(src_addr + 57 * src_stride_z)); |
| 3633 | float d72 = *((__global float *)(src_addr + 58 * src_stride_z)); |
| 3634 | float d73 = *((__global float *)(src_addr + 59 * src_stride_z)); |
| 3635 | float d74 = *((__global float *)(src_addr + 60 * src_stride_z)); |
| 3636 | float d75 = *((__global float *)(src_addr + 61 * src_stride_z)); |
| 3637 | float d76 = *((__global float *)(src_addr + 62 * src_stride_z)); |
| 3638 | float d77 = *((__global float *)(src_addr + 63 * src_stride_z)); |
| 3639 | |
| 3640 | // Compute the 8x4 intermediate tensor |
| 3641 | float4 comm_fact0, comm_fact1, comm_fact2; |
| 3642 | float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7; |
| 3643 | |
| 3644 | COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0); |
| 3645 | COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0); |
| 3646 | COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0); |
| 3647 | COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0); |
| 3648 | COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0); |
| 3649 | COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0); |
| 3650 | COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0); |
| 3651 | COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0); |
| 3652 | |
| 3653 | // Compute the 4x4 output tile |
| 3654 | comm_fact0 = tmp_col1 + tmp_col2; |
| 3655 | comm_fact1 = tmp_col3 + tmp_col4; |
| 3656 | comm_fact2 = tmp_col5 + tmp_col6; |
| 3657 | |
| 3658 | float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0; |
| 3659 | float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2; |
| 3660 | |
| 3661 | comm_fact0 = tmp_col1 - tmp_col2; |
| 3662 | comm_fact1 = tmp_col3 - tmp_col4; |
| 3663 | comm_fact2 = tmp_col5 - tmp_col6; |
| 3664 | |
| 3665 | float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2; |
| 3666 | float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7; |
| 3667 | |
| 3668 | int y_in = get_global_id(1); |
| 3669 | int x_out = (y_in % NUM_TILES_X) * 4; |
| 3670 | int y_out = (y_in / NUM_TILES_X) * 4; |
| 3671 | int z_out = get_global_id(0); |
| 3672 | |
| 3673 | #if defined(HAS_BIAS) |
| 3674 | // Add bias |
| 3675 | Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias); |
| 3676 | |
| 3677 | float b = (float) * ((__global float *)(vector_offset(&bias, z_out))); |
| 3678 | |
| 3679 | out_col0 += (float4)b; |
| 3680 | out_col1 += (float4)b; |
| 3681 | out_col2 += (float4)b; |
| 3682 | out_col3 += (float4)b; |
| 3683 | #endif // defined(HAS_BIAS) |
| 3684 | |
| 3685 | // Get output address |
| 3686 | __global uchar *dst_addr = dst_ptr + dst_offset_first_element_in_bytes + x_out * dst_stride_x + y_out * dst_stride_y + z_out * dst_stride_z; |
| 3687 | |
| 3688 | // Store the 4x4 output tile |
| 3689 | *(__global float *)(dst_addr + 0 * dst_stride_x + 0 * dst_stride_y) = out_col0.s0; |
| 3690 | *(__global float *)(dst_addr + 1 * dst_stride_x + 0 * dst_stride_y) = out_col1.s0; |
| 3691 | *(__global float *)(dst_addr + 2 * dst_stride_x + 0 * dst_stride_y) = out_col2.s0; |
| 3692 | *(__global float *)(dst_addr + 3 * dst_stride_x + 0 * dst_stride_y) = out_col3.s0; |
| 3693 | *(__global float *)(dst_addr + 0 * dst_stride_x + 1 * dst_stride_y) = out_col0.s1; |
| 3694 | *(__global float *)(dst_addr + 1 * dst_stride_x + 1 * dst_stride_y) = out_col1.s1; |
| 3695 | *(__global float *)(dst_addr + 2 * dst_stride_x + 1 * dst_stride_y) = out_col2.s1; |
| 3696 | *(__global float *)(dst_addr + 3 * dst_stride_x + 1 * dst_stride_y) = out_col3.s1; |
| 3697 | *(__global float *)(dst_addr + 0 * dst_stride_x + 2 * dst_stride_y) = out_col0.s2; |
| 3698 | *(__global float *)(dst_addr + 1 * dst_stride_x + 2 * dst_stride_y) = out_col1.s2; |
| 3699 | *(__global float *)(dst_addr + 2 * dst_stride_x + 2 * dst_stride_y) = out_col2.s2; |
| 3700 | *(__global float *)(dst_addr + 3 * dst_stride_x + 2 * dst_stride_y) = out_col3.s2; |
| 3701 | *(__global float *)(dst_addr + 0 * dst_stride_x + 3 * dst_stride_y) = out_col0.s3; |
| 3702 | *(__global float *)(dst_addr + 1 * dst_stride_x + 3 * dst_stride_y) = out_col1.s3; |
| 3703 | *(__global float *)(dst_addr + 2 * dst_stride_x + 3 * dst_stride_y) = out_col2.s3; |
| 3704 | *(__global float *)(dst_addr + 3 * dst_stride_x + 3 * dst_stride_y) = out_col3.s3; |
| 3705 | } |
Giorgio Arena | 7210fe8 | 2018-06-08 12:24:14 +0100 | [diff] [blame] | 3706 | |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3707 | /** This OpenCL kernel performs Winograd output transform when the output tile is 4x4, the filter size 5x5 and the data layout is NHWC |
Giorgio Arena | 7210fe8 | 2018-06-08 12:24:14 +0100 | [diff] [blame] | 3708 | * |
| 3709 | * @note The number of tiles along the X direction must be passed at compile time using -DNUM_TILES_X: e.g. -DNUM_TILES_X=16 |
| 3710 | * |
| 3711 | * @param[in] src_ptr Pointer to the source tensor. Supported data types: F32 |
| 3712 | * @param[in] src_stride_x Stride of the source tensor in X dimension (in bytes) |
| 3713 | * @param[in] src_step_x src_stride_x * number of elements along X processed per workitem(in bytes) |
| 3714 | * @param[in] src_stride_y Stride of the source tensor in Y dimension (in bytes) |
| 3715 | * @param[in] src_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3716 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3717 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3718 | * @param[in] src_offset_first_element_in_bytes The offset of the first element in the source tensor |
| 3719 | * @param[out] dst_ptr Pointer to the destination tensor. Supported data types: same as @p src_ptr |
| 3720 | * @param[in] dst_stride_x Stride of the destination tensor in X dimension (in bytes) |
| 3721 | * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) |
| 3722 | * @param[in] dst_stride_y Stride of the destination tensor in Y dimension (in bytes) |
| 3723 | * @param[in] dst_step_y dst_stride_y * number of elements along Y processed per workitem(in bytes) |
| 3724 | * @param[in] src_stride_z Stride of the source tensor in Z dimension (in bytes) |
| 3725 | * @param[in] src_step_z src_stride_z * number of elements along Z processed per workitem(in bytes) |
| 3726 | * @param[in] dst_offset_first_element_in_bytes The offset of the first element in the destination tensor |
| 3727 | */ |
| 3728 | __kernel void winograd_output_transform_4x4_5x5_nhwc( |
| 3729 | TENSOR3D_DECLARATION(src), |
| 3730 | TENSOR3D_DECLARATION(dst), |
| 3731 | #if defined(HAS_BIAS) |
| 3732 | VECTOR_DECLARATION(bias), |
| 3733 | #endif // defined(HAS_BIAS) |
| 3734 | int dst_size) |
| 3735 | { |
| 3736 | // Each thread stores a 4x4 tile |
| 3737 | Tensor3D src = CONVERT_TO_TENSOR3D_STRUCT(src); |
| 3738 | |
| 3739 | const __global uchar *src_addr = tensor3D_offset(&src, 0, 0, 0); |
| 3740 | |
| 3741 | // Load the values across the 64 channels to compose the 8x8 input tile |
| 3742 | float d00 = *((__global float *)(src_addr + 0 * src_stride_z)); |
| 3743 | float d01 = *((__global float *)(src_addr + 1 * src_stride_z)); |
| 3744 | float d02 = *((__global float *)(src_addr + 2 * src_stride_z)); |
| 3745 | float d03 = *((__global float *)(src_addr + 3 * src_stride_z)); |
| 3746 | float d04 = *((__global float *)(src_addr + 4 * src_stride_z)); |
| 3747 | float d05 = *((__global float *)(src_addr + 5 * src_stride_z)); |
| 3748 | float d06 = *((__global float *)(src_addr + 6 * src_stride_z)); |
| 3749 | float d07 = *((__global float *)(src_addr + 7 * src_stride_z)); |
| 3750 | |
| 3751 | float d10 = *((__global float *)(src_addr + 8 * src_stride_z)); |
| 3752 | float d11 = *((__global float *)(src_addr + 9 * src_stride_z)); |
| 3753 | float d12 = *((__global float *)(src_addr + 10 * src_stride_z)); |
| 3754 | float d13 = *((__global float *)(src_addr + 11 * src_stride_z)); |
| 3755 | float d14 = *((__global float *)(src_addr + 12 * src_stride_z)); |
| 3756 | float d15 = *((__global float *)(src_addr + 13 * src_stride_z)); |
| 3757 | float d16 = *((__global float *)(src_addr + 14 * src_stride_z)); |
| 3758 | float d17 = *((__global float *)(src_addr + 15 * src_stride_z)); |
| 3759 | |
| 3760 | float d20 = *((__global float *)(src_addr + 16 * src_stride_z)); |
| 3761 | float d21 = *((__global float *)(src_addr + 17 * src_stride_z)); |
| 3762 | float d22 = *((__global float *)(src_addr + 18 * src_stride_z)); |
| 3763 | float d23 = *((__global float *)(src_addr + 19 * src_stride_z)); |
| 3764 | float d24 = *((__global float *)(src_addr + 20 * src_stride_z)); |
| 3765 | float d25 = *((__global float *)(src_addr + 21 * src_stride_z)); |
| 3766 | float d26 = *((__global float *)(src_addr + 22 * src_stride_z)); |
| 3767 | float d27 = *((__global float *)(src_addr + 23 * src_stride_z)); |
| 3768 | |
| 3769 | float d30 = *((__global float *)(src_addr + 24 * src_stride_z)); |
| 3770 | float d31 = *((__global float *)(src_addr + 25 * src_stride_z)); |
| 3771 | float d32 = *((__global float *)(src_addr + 26 * src_stride_z)); |
| 3772 | float d33 = *((__global float *)(src_addr + 27 * src_stride_z)); |
| 3773 | float d34 = *((__global float *)(src_addr + 28 * src_stride_z)); |
| 3774 | float d35 = *((__global float *)(src_addr + 29 * src_stride_z)); |
| 3775 | float d36 = *((__global float *)(src_addr + 30 * src_stride_z)); |
| 3776 | float d37 = *((__global float *)(src_addr + 31 * src_stride_z)); |
| 3777 | |
| 3778 | float d40 = *((__global float *)(src_addr + 32 * src_stride_z)); |
| 3779 | float d41 = *((__global float *)(src_addr + 33 * src_stride_z)); |
| 3780 | float d42 = *((__global float *)(src_addr + 34 * src_stride_z)); |
| 3781 | float d43 = *((__global float *)(src_addr + 35 * src_stride_z)); |
| 3782 | float d44 = *((__global float *)(src_addr + 36 * src_stride_z)); |
| 3783 | float d45 = *((__global float *)(src_addr + 37 * src_stride_z)); |
| 3784 | float d46 = *((__global float *)(src_addr + 38 * src_stride_z)); |
| 3785 | float d47 = *((__global float *)(src_addr + 39 * src_stride_z)); |
| 3786 | |
| 3787 | float d50 = *((__global float *)(src_addr + 40 * src_stride_z)); |
| 3788 | float d51 = *((__global float *)(src_addr + 41 * src_stride_z)); |
| 3789 | float d52 = *((__global float *)(src_addr + 42 * src_stride_z)); |
| 3790 | float d53 = *((__global float *)(src_addr + 43 * src_stride_z)); |
| 3791 | float d54 = *((__global float *)(src_addr + 44 * src_stride_z)); |
| 3792 | float d55 = *((__global float *)(src_addr + 45 * src_stride_z)); |
| 3793 | float d56 = *((__global float *)(src_addr + 46 * src_stride_z)); |
| 3794 | float d57 = *((__global float *)(src_addr + 47 * src_stride_z)); |
| 3795 | |
| 3796 | float d60 = *((__global float *)(src_addr + 48 * src_stride_z)); |
| 3797 | float d61 = *((__global float *)(src_addr + 49 * src_stride_z)); |
| 3798 | float d62 = *((__global float *)(src_addr + 50 * src_stride_z)); |
| 3799 | float d63 = *((__global float *)(src_addr + 51 * src_stride_z)); |
| 3800 | float d64 = *((__global float *)(src_addr + 52 * src_stride_z)); |
| 3801 | float d65 = *((__global float *)(src_addr + 53 * src_stride_z)); |
| 3802 | float d66 = *((__global float *)(src_addr + 54 * src_stride_z)); |
| 3803 | float d67 = *((__global float *)(src_addr + 55 * src_stride_z)); |
| 3804 | |
| 3805 | float d70 = *((__global float *)(src_addr + 56 * src_stride_z)); |
| 3806 | float d71 = *((__global float *)(src_addr + 57 * src_stride_z)); |
| 3807 | float d72 = *((__global float *)(src_addr + 58 * src_stride_z)); |
| 3808 | float d73 = *((__global float *)(src_addr + 59 * src_stride_z)); |
| 3809 | float d74 = *((__global float *)(src_addr + 60 * src_stride_z)); |
| 3810 | float d75 = *((__global float *)(src_addr + 61 * src_stride_z)); |
| 3811 | float d76 = *((__global float *)(src_addr + 62 * src_stride_z)); |
| 3812 | float d77 = *((__global float *)(src_addr + 63 * src_stride_z)); |
| 3813 | |
| 3814 | // Compute the 8x4 intermediate tensor |
| 3815 | float4 comm_fact0, comm_fact1, comm_fact2; |
| 3816 | float4 tmp_col0, tmp_col1, tmp_col2, tmp_col3, tmp_col4, tmp_col5, tmp_col6, tmp_col7; |
| 3817 | |
| 3818 | COMPUTE_TMP_COL(tmp_col0, d00, d10, d20, d30, d40, d50, d60, d70, comm_fact0); |
| 3819 | COMPUTE_TMP_COL(tmp_col1, d01, d11, d21, d31, d41, d51, d61, d71, comm_fact0); |
| 3820 | COMPUTE_TMP_COL(tmp_col2, d02, d12, d22, d32, d42, d52, d62, d72, comm_fact0); |
| 3821 | COMPUTE_TMP_COL(tmp_col3, d03, d13, d23, d33, d43, d53, d63, d73, comm_fact0); |
| 3822 | COMPUTE_TMP_COL(tmp_col4, d04, d14, d24, d34, d44, d54, d64, d74, comm_fact0); |
| 3823 | COMPUTE_TMP_COL(tmp_col5, d05, d15, d25, d35, d45, d55, d65, d75, comm_fact0); |
| 3824 | COMPUTE_TMP_COL(tmp_col6, d06, d16, d26, d36, d46, d56, d66, d76, comm_fact0); |
| 3825 | COMPUTE_TMP_COL(tmp_col7, d07, d17, d27, d37, d47, d57, d67, d77, comm_fact0); |
| 3826 | |
| 3827 | // Compute the 4x4 output tile |
| 3828 | comm_fact0 = tmp_col1 + tmp_col2; |
| 3829 | comm_fact1 = tmp_col3 + tmp_col4; |
| 3830 | comm_fact2 = tmp_col5 + tmp_col6; |
| 3831 | |
| 3832 | float4 out_col0 = comm_fact0 + comm_fact1 + 8.f * comm_fact2 + tmp_col0; |
| 3833 | float4 out_col2 = comm_fact0 + 4.f * comm_fact1 + 2.f * comm_fact2; |
| 3834 | |
| 3835 | comm_fact0 = tmp_col1 - tmp_col2; |
| 3836 | comm_fact1 = tmp_col3 - tmp_col4; |
| 3837 | comm_fact2 = tmp_col5 - tmp_col6; |
| 3838 | |
| 3839 | float4 out_col1 = comm_fact0 + 2.f * comm_fact1 + 4.f * comm_fact2; |
| 3840 | float4 out_col3 = comm_fact0 + 8.f * comm_fact1 + comm_fact2 + tmp_col7; |
| 3841 | |
| 3842 | int y_in = get_global_id(1); |
| 3843 | int x_out = get_global_id(0); |
| 3844 | int y_out = (y_in % NUM_TILES_X) * 4; |
| 3845 | int z_out = (y_in / NUM_TILES_X) * 4; |
| 3846 | |
| 3847 | #if defined(HAS_BIAS) |
| 3848 | // Add bias |
| 3849 | Vector bias = CONVERT_TO_VECTOR_STRUCT_NO_STEP(bias); |
| 3850 | |
| 3851 | float b = (float) * ((__global float *)(vector_offset(&bias, z_out))); |
| 3852 | |
| 3853 | out_col0 += (float4)b; |
| 3854 | out_col1 += (float4)b; |
| 3855 | out_col2 += (float4)b; |
| 3856 | out_col3 += (float4)b; |
| 3857 | #endif // defined(HAS_BIAS) |
| 3858 | |
| 3859 | // Get output address |
| 3860 | int4 offset = (int4)(dst_offset_first_element_in_bytes + x_out * sizeof(float) + y_out * dst_stride_y + z_out * dst_stride_z); |
| 3861 | offset = min(offset + (int4)(0, 1, 2, 3) * (int4)dst_stride_z, dst_size); // If address is beyond the last plane, clamp it to dst_size (which points to the last padding). |
| 3862 | int4 mult_y = min(dst_size - offset, 1); // If out of bound, we don't want to increase dst_stride_y, so we set the multiplier to 0. It will be 1 otherwise. |
| 3863 | |
| 3864 | // Store the 4x4 output tile |
| 3865 | *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s0) = out_col0.s0; |
| 3866 | *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s0) = out_col1.s0; |
| 3867 | *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s0) = out_col2.s0; |
| 3868 | *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s0) = out_col3.s0; |
| 3869 | *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s1) = out_col0.s1; |
| 3870 | *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s1) = out_col1.s1; |
| 3871 | *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s1) = out_col2.s1; |
| 3872 | *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s1) = out_col3.s1; |
| 3873 | *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s2) = out_col0.s2; |
| 3874 | *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s2) = out_col1.s2; |
| 3875 | *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s2) = out_col2.s2; |
| 3876 | *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s2) = out_col3.s2; |
| 3877 | *(__global float *)(dst_ptr + mult_y.s0 * 0 * dst_stride_y + offset.s3) = out_col0.s3; |
| 3878 | *(__global float *)(dst_ptr + mult_y.s0 * 1 * dst_stride_y + offset.s3) = out_col1.s3; |
| 3879 | *(__global float *)(dst_ptr + mult_y.s0 * 2 * dst_stride_y + offset.s3) = out_col2.s3; |
| 3880 | *(__global float *)(dst_ptr + mult_y.s0 * 3 * dst_stride_y + offset.s3) = out_col3.s3; |
| 3881 | } |
Gian Marco Iodice | f1c2bf0 | 2018-06-13 14:05:54 +0100 | [diff] [blame^] | 3882 | #endif // defined(NUM_TILES_X) && defined(OUTPUT_TILE_W) && defined(OUTPUT_TILE_H) |