blob: 58bed71a4717a450ebe8d3854973d8d3b67dfb86 [file] [log] [blame]
/*
 * Copyright (c) 2017 ARM Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
25#include "arm_compute/core/NEON/kernels/convolution/winograd/transforms/output.hpp"
Pablo Tellod3d97d22018-10-05 10:59:48 +010026#include "arm_compute/core/NEON/kernels/convolution/winograd/winograd_output_transform.hpp"
Pablo Tellobda6e4b2018-08-22 11:40:33 +010027#include "arm_compute/core/NEON/kernels/convolution/common/arm.hpp"
28
namespace
{

/** Winograd output transform of one 1x8 Winograd-domain tile into (up to) a
 *  1x6 row of fp32 outputs, for every channel.
 *
 * For channel c, the eight Winograd-domain values are read from
 * matrix_base[j * matrix_stride + c] (j = 0..7). Output column j of channel c
 * is written to output[j * output_col_stride + c], for j < (6 - pad_right).
 *
 * Each output is
 *   f[k] = [k == 0] * F[0] + sum_{p=1..6} F[p] * x_p^k + [k == 5] * F[7]
 * where the per-row coefficients encode the points x_1..x_6 = -1, 1, -2, 2, -3, 3.
 * If biases is non-null, biases[c] is added to every output of channel c.
 * (The name suggests this is the F(6, 3) output transform.)
 *
 * @tparam Specialized If true, use the compile-time PadRight and ignore _pad_right.
 * @tparam PadRight    Rightmost output columns to skip (only when Specialized).
 *
 * @param n_channels        Number of channels to transform.
 * @param matrix_base       Winograd-domain input (layout described above).
 * @param matrix_stride     Distance, in floats, between consecutive tile points.
 * @param biases            Per-channel biases, or nullptr for no bias.
 * @param output            Address of output column 0, channel 0.
 * @param output_row_stride Unused by this single-row transform.
 * @param output_col_stride Distance, in floats, between output columns.
 * @param _pad_bottom       Unused by this single-row transform.
 * @param _pad_right        Rightmost output columns to skip (only when !Specialized).
 */
template <bool Specialized, int PadRight=0>
void winograd_output_transform_6_3_fp32_process_tile(
  const int n_channels,
  const float* const matrix_base,
  const int matrix_stride,
  const float* const biases,
  float* const output,
  const int output_row_stride,
  const int output_col_stride,
  const int _pad_bottom,
  const int _pad_right
)
{
  (void) output_row_stride;
  (void) _pad_bottom;
  constexpr int output_tile_cols = 6;
  constexpr int inner_tile_cols = 8;

  const int pad_right = Specialized ? PadRight : _pad_right;

  // Number of output columns actually written. Clamped to [0, output_tile_cols]
  // so that an out-of-range padding value can never index past outptrs[] / f[].
  const int cells_j = (pad_right <= 0) ? output_tile_cols
                    : (pad_right >= output_tile_cols) ? 0
                    : output_tile_cols - pad_right;

  // Construct a map to the output cells. Sized by the compile-time maximum:
  // a runtime-sized array (VLA) is not standard C++.
  float *outptrs[output_tile_cols];
  for (int j = 0; j < cells_j; j++)
  {
    outptrs[j] = output + j*output_col_stride;
  }
  const float *inptr = matrix_base;
  const float *bptr = biases;

  // For each channel of the output
  int channels_remaining = n_channels;
#ifdef __arm_any__
  // Four channels at a time.
  for (; channels_remaining >= 4; channels_remaining -= 4)
  {
    // Matrices used and computed during this transform
    float32x4_t F[inner_tile_cols], f[output_tile_cols], b = vdupq_n_f32(0.0f);

    // Read a 1x8 tile in the Winograd domain
    for (int j = 0; j < inner_tile_cols; j++)
    {
      F[j] = vld1q_f32(inptr + j*matrix_stride);
    }
    inptr += 4;

    f[0] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[6], 1), F[5], 1), F[4], 1), F[3], 1), F[2], 1), F[1], 1), F[0], 1);
    f[1] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[2], 1), F[6], 3), F[4], 2), F[3], -2), F[5], -3), F[1], -1);
    f[2] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[2], 1), F[1], 1), F[6], 9), F[5], 9), F[4], 4), F[3], 4);
    f[3] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[2], 1), F[6], 27), F[4], 8), F[3], -8), F[5], -27), F[1], -1);
    f[4] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[2], 1), F[1], 1), F[6], 81), F[5], 81), F[4], 16), F[3], 16);
    f[5] = vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmlaq_n_f32(vmulq_n_f32(F[7], 1), F[2], 1), F[6], 243), F[4], 32), F[3], -32), F[5], -243), F[1], -1);

    // Write out the output tile
    if (bptr != nullptr)
    {
      b = vld1q_f32(bptr);
      bptr += 4;
    }
    for (int j = 0; j < cells_j; j++)
    {
      // Intrinsic rather than the GCC vector-extension operator+ for portability.
      vst1q_f32(outptrs[j], vaddq_f32(f[j], b));
      outptrs[j] += 4;
    }
  }
  // Two channels at a time.
  for (; channels_remaining >= 2; channels_remaining -= 2)
  {
    // Matrices used and computed during this transform
    float32x2_t F[inner_tile_cols], f[output_tile_cols], b = vdup_n_f32(0.0f);

    // Read a 1x8 tile in the Winograd domain
    for (int j = 0; j < inner_tile_cols; j++)
    {
      F[j] = vld1_f32(inptr + j*matrix_stride);
    }
    inptr += 2;

    f[0] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[6], 1), F[5], 1), F[4], 1), F[3], 1), F[2], 1), F[1], 1), F[0], 1);
    f[1] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[2], 1), F[6], 3), F[4], 2), F[3], -2), F[5], -3), F[1], -1);
    f[2] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[2], 1), F[1], 1), F[6], 9), F[5], 9), F[4], 4), F[3], 4);
    f[3] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[2], 1), F[6], 27), F[4], 8), F[3], -8), F[5], -27), F[1], -1);
    f[4] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[2], 1), F[1], 1), F[6], 81), F[5], 81), F[4], 16), F[3], 16);
    f[5] = vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmla_n_f32(vmul_n_f32(F[7], 1), F[2], 1), F[6], 243), F[4], 32), F[3], -32), F[5], -243), F[1], -1);

    // Write out the output tile
    if (bptr != nullptr)
    {
      b = vld1_f32(bptr);
      bptr += 2;
    }
    for (int j = 0; j < cells_j; j++)
    {
      vst1_f32(outptrs[j], vadd_f32(f[j], b));
      outptrs[j] += 2;
    }
  }
#endif // __arm_any__
  // Remaining channels, one at a time.
  for (; channels_remaining; channels_remaining--)
  {
    // Matrices used and computed during this transform
    float F[inner_tile_cols], f[output_tile_cols], b = 0.0f;

    // Read a 1x8 tile in the Winograd domain
    for (int j = 0; j < inner_tile_cols; j++)
    {
      F[j] = *(inptr + j*matrix_stride);
    }
    inptr++;

    f[0] = F[0]*1 + F[1]*1 + F[2]*1 + F[3]*1 + F[4]*1 + F[5]*1 + F[6]*1;
    f[1] = F[1]*-1 + F[5]*-3 + F[3]*-2 + F[4]*2 + F[6]*3 + F[2]*1;
    f[2] = F[3]*4 + F[4]*4 + F[5]*9 + F[6]*9 + F[1]*1 + F[2]*1;
    f[3] = F[1]*-1 + F[5]*-27 + F[3]*-8 + F[4]*8 + F[6]*27 + F[2]*1;
    f[4] = F[3]*16 + F[4]*16 + F[5]*81 + F[6]*81 + F[1]*1 + F[2]*1;
    f[5] = F[1]*-1 + F[5]*-243 + F[3]*-32 + F[4]*32 + F[6]*243 + F[2]*1 + F[7]*1;

    // Write out the output tile
    if (bptr != nullptr)
    {
      b = *(bptr++);
    }
    for (int j = 0; j < cells_j; j++)
    {
      *(outptrs[j]++) = f[j] + b;
    }
  }
}

} // namespace (anonymous)
160
161namespace winograd
Pablo Tellobda6e4b2018-08-22 11:40:33 +0100162{
Pablo Tellod3d97d22018-10-05 10:59:48 +0100163using Tiles = OutputTransformImplTiles<1, 3, 1, 8, float>;
164
165template <>
166const Tiles::TileFn Tiles::tilefn_unpadded = winograd_output_transform_6_3_fp32_process_tile<true>;
167
168template <>
169const Tiles::TileFn Tiles::tilefn_right_padded[n_pad_right] = {
170 winograd_output_transform_6_3_fp32_process_tile<true, 1>,
171 winograd_output_transform_6_3_fp32_process_tile<true, 2>,
172 winograd_output_transform_6_3_fp32_process_tile<true, 3>,
173 winograd_output_transform_6_3_fp32_process_tile<true, 4>,
174 winograd_output_transform_6_3_fp32_process_tile<true, 5>,
Pablo Tellobda6e4b2018-08-22 11:40:33 +0100175};
176
Pablo Tellod3d97d22018-10-05 10:59:48 +0100177template class OutputTransform<1, 3, 1, 8, float>;
178template class OutputTransform<3, 1, 8, 1, float>;
Pablo Tellobda6e4b2018-08-22 11:40:33 +0100179} // namespace winograd