blob: 72c43019faa0a75a0422ca6fc6e329a0e83c4c7b [file] [log] [blame]
/*
 * Copyright (c) 2022, 2024 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24
ramelg01a1f78512022-06-29 16:28:10 +010025#include <algorithm>
26#include <cstddef>
27#include <arm_neon.h>
Pablo Tello8f43d742019-03-27 09:28:32 +000028
namespace arm_conv {
namespace winograd {
namespace output_transform {

namespace
{

/* Element-wise operations used by the F(4x4, 3x3) output transform, one struct
 * per channel-block width: 4 channels (float32x4_t), 2 channels (float32x2_t)
 * and 1 channel (float).
 *
 * A single templated tile transform (below) is instantiated with each struct,
 * so all three channel loops evaluate exactly the same expression tree. In
 * particular the scalar tail uses the same association order as the NEON
 * lanes instead of a plain left-to-right sum.
 */
struct Fp32x4
{
  typedef float32x4_t vec;
  enum : unsigned int { width = 4 };

  static vec load(const float *p) { return vld1q_f32(p); }
  static void store(float *p, vec v) { vst1q_f32(p, v); }
  static vec dup(float x) { return vdupq_n_f32(x); }
  static vec add(vec a, vec b) { return vaddq_f32(a, b); }
  static vec sub(vec a, vec b) { return vsubq_f32(a, b); }
  // acc + a * k
  static vec mla(vec acc, vec a, float k) { return vmlaq_n_f32(acc, a, k); }
  // max(min(v, hi), lo)
  static vec clamp(vec v, vec lo, vec hi) { return vmaxq_f32(vminq_f32(v, hi), lo); }
};

struct Fp32x2
{
  typedef float32x2_t vec;
  enum : unsigned int { width = 2 };

  static vec load(const float *p) { return vld1_f32(p); }
  static void store(float *p, vec v) { vst1_f32(p, v); }
  static vec dup(float x) { return vdup_n_f32(x); }
  static vec add(vec a, vec b) { return vadd_f32(a, b); }
  static vec sub(vec a, vec b) { return vsub_f32(a, b); }
  // acc + a * k
  static vec mla(vec acc, vec a, float k) { return vmla_n_f32(acc, a, k); }
  // max(min(v, hi), lo)
  static vec clamp(vec v, vec lo, vec hi) { return vmax_f32(vmin_f32(v, hi), lo); }
};

struct Fp32x1
{
  typedef float vec;
  enum : unsigned int { width = 1 };

  static vec load(const float *p) { return *p; }
  static void store(float *p, vec v) { *p = v; }
  static vec dup(float x) { return x; }
  static vec add(vec a, vec b) { return a + b; }
  static vec sub(vec a, vec b) { return a - b; }
  // acc + a * k
  static vec mla(vec acc, vec a, float k) { return acc + a * k; }
  // max(min(v, hi), lo)
  static vec clamp(vec v, vec lo, vec hi) { return std::max(std::min(v, hi), lo); }
};

/* Compute y = x^T Z for a 6-element x, where Z is the 6x4 output-transform
 * matrix:
 *
 *   y0 = x0 + x1 + x2 + x3 + x4
 *   y1 = x1 - x2 + 2*(x3 - x4)
 *   y2 = x1 + x2 + 4*(x3 + x4)
 *   y3 = x1 - x2 + 8*(x3 - x4) + x5
 *
 * Applied to each row of F this gives F Z; applied to each column of F Z it
 * gives Z^T F Z.
 */
template <typename Ops>
inline void apply_z(
  typename Ops::vec x0, typename Ops::vec x1, typename Ops::vec x2,
  typename Ops::vec x3, typename Ops::vec x4, typename Ops::vec x5,
  typename Ops::vec &y0, typename Ops::vec &y1,
  typename Ops::vec &y2, typename Ops::vec &y3)
{
  const auto d12 = Ops::sub(x1, x2), d34 = Ops::sub(x3, x4);

  y0 = Ops::add(Ops::add(Ops::add(x0, x1), Ops::add(x2, x3)), x4);
  y1 = Ops::mla(d12, d34, 2.0f);
  y2 = Ops::mla(Ops::add(x1, x2), Ops::add(x3, x4), 4.0f);
  y3 = Ops::add(Ops::mla(d12, d34, 8.0f), x5);
}

/* Transform one 6x6 Winograd-domain tile for Ops::width adjacent channels.
 *
 * Tile element (i, j) is read from inptr[(6*i + j) * matrix_stride]. Output
 * element (i, j) is written to
 * outptr[i * output_row_stride + j * output_col_stride]. When bptr is
 * non-null it supplies the bias for these channels; otherwise a bias of 0.0f
 * is added. Results are clamped to [output_min, output_max].
 */
template <typename Ops>
inline void transform_tile(
  const float *inptr, size_t matrix_stride,
  const float *bptr,
  float *outptr, size_t output_row_stride, size_t output_col_stride,
  float output_min, float output_max)
{
  typedef typename Ops::vec vec;
  constexpr auto output_tile_rows = 4u, output_tile_cols = 4u;

  // Gather the 6x6 tile F; matrix number 6*i + j holds F[i][j].
  vec F[6][6];
  for (auto i = 0u; i < 6; i++)
  {
    for (auto j = 0u; j < 6; j++)
    {
      F[i][j] = Ops::load(inptr + (6*i + j)*matrix_stride);
    }
  }

  // FZ = F Z, one row at a time.
  vec FZ[6][4];
  for (auto i = 0u; i < 6; i++)
  {
    apply_z<Ops>(F[i][0], F[i][1], F[i][2], F[i][3], F[i][4], F[i][5],
                 FZ[i][0], FZ[i][1], FZ[i][2], FZ[i][3]);
  }

  // f = Z^T (F Z), one column at a time.
  vec f[output_tile_rows][output_tile_cols];
  for (auto j = 0u; j < output_tile_cols; j++)
  {
    apply_z<Ops>(FZ[0][j], FZ[1][j], FZ[2][j], FZ[3][j], FZ[4][j], FZ[5][j],
                 f[0][j], f[1][j], f[2][j], f[3][j]);
  }

  // Bias, activation bounds: loop-invariant, so broadcast once per tile.
  const vec b  = (bptr != nullptr) ? Ops::load(bptr) : Ops::dup(0.0f);
  const vec lo = Ops::dup(output_min);
  const vec hi = Ops::dup(output_max);

  for (auto i = 0u; i < output_tile_rows; i++)
  {
    for (auto j = 0u; j < output_tile_cols; j++)
    {
      Ops::store(outptr + i*output_row_stride + j*output_col_stride,
                 Ops::clamp(Ops::add(f[i][j], b), lo, hi));
    }
  }
}

/* Transform as many whole blocks of Ops::width channels as remain. Decrements
 * n_channels and advances inptr, outptr and (when non-null) bptr past every
 * channel processed.
 */
template <typename Ops>
void transform_channel_blocks(
  unsigned int &n_channels,
  const float *&inptr, size_t matrix_stride,
  const float *&bptr,
  float *&outptr, size_t output_row_stride, size_t output_col_stride,
  float output_min, float output_max)
{
  for (; n_channels >= Ops::width; n_channels -= Ops::width)
  {
    transform_tile<Ops>(inptr, matrix_stride, bptr,
                        outptr, output_row_stride, output_col_stride,
                        output_min, output_max);
    inptr += Ops::width;
    outptr += Ops::width;
    if (bptr != nullptr)
    {
      bptr += Ops::width;
    }
  }
}

}  // namespace

/** Winograd F(4x4, 3x3) output transform for fp32 data.
 *
 * For every channel, gathers a 6x6 Winograd-domain tile F, computes the 4x4
 * tile f = Z^T F Z, adds an optional per-channel bias, clamps the result to
 * [output_min, output_max] and stores it. Channels are processed four at a
 * time, then two at a time, then singly. All three widths share one transform
 * implementation.
 *
 * @param n_channels        Number of channels to transform.
 * @param inptr             Winograd-domain input. Tile element (i, j) of
 *                          channel c is inptr[(6*i + j) * matrix_stride + c].
 * @param matrix_stride     Distance, in floats, between consecutive matrices
 *                          of the 36-matrix input.
 * @param bptr              Per-channel bias (bptr[c]), or nullptr for none.
 * @param outptr            Output. Element (i, j) of channel c is written to
 *                          outptr[i*output_row_stride + j*output_col_stride + c].
 * @param output_row_stride Distance, in floats, between output rows.
 * @param output_col_stride Distance, in floats, between output columns.
 * @param output_min        Lower clamp bound applied after the bias.
 * @param output_max        Upper clamp bound applied after the bias.
 */
void arm_fp32_4x4_3x3(
  unsigned int n_channels,
  const float* inptr,
  size_t matrix_stride,
  const float* bptr,
  float *outptr,
  size_t output_row_stride,
  size_t output_col_stride,
  float output_min,
  float output_max
)
{
  transform_channel_blocks<Fp32x4>(n_channels, inptr, matrix_stride, bptr,
                                   outptr, output_row_stride, output_col_stride,
                                   output_min, output_max);
  transform_channel_blocks<Fp32x2>(n_channels, inptr, matrix_stride, bptr,
                                   outptr, output_row_stride, output_col_stride,
                                   output_min, output_max);
  transform_channel_blocks<Fp32x1>(n_channels, inptr, matrix_stride, bptr,
                                   outptr, output_row_stride, output_col_stride,
                                   output_min, output_max);
}

} // namespace output_transform
} // namespace winograd
} // namespace arm_conv