blob: b2645907913ae3151a56a98cbff3d52be567d031 [file] [log] [blame]
giuros0114c4e0f2019-03-26 17:44:40 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
25
26#include "arm_compute/core/ITensor.h"
27#include "arm_compute/core/NEON/wrapper/traits.h"
28#include "arm_compute/core/NEON/wrapper/wrapper.h"
29#include "arm_compute/core/TensorInfo.h"
30#include "arm_compute/core/Types.h"
31#include "arm_compute/core/Utils.h"
32#include "arm_compute/core/Window.h"
33
34#include <arm_neon.h>
35#include <cmath>
36#include <complex>
37
38namespace arm_compute
39{
40namespace
41{
42constexpr float PI = 3.141592653589793f;
43
44float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
45{
46 float32x2_t tmp = wrapper::vmul(a, b);
47
48 const float P1 = wrapper::vgetlane(tmp, 0);
49 const float P2 = wrapper::vgetlane(tmp, 1);
50
51 const float a_r = wrapper::vgetlane(a, 0);
52 const float a_i = wrapper::vgetlane(a, 1);
53 const float b_r = wrapper::vgetlane(b, 0);
54 const float b_i = wrapper::vgetlane(b, 1);
55
56 const float P3 = (a_r + a_i) * (b_r + b_i);
57 float32x2_t out = { P1 - P2, P3 - P2 - P1 };
58 return out;
59}
60
61float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
62{
63 const float a_r = wrapper::vgetlane(a, 0);
64 const float a_i = wrapper::vgetlane(a, 1);
65
66 const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
67 return out;
68}
69
70float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
71{
72 const auto t0 = wrapper::vadd(a, b);
73 const auto t1 = wrapper::vadd(c, d);
74 const auto t2 = wrapper::vadd(t0, t1);
75 return wrapper::vadd(t2, e);
76}
77
78float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
79{
80 const auto t0 = wrapper::vadd(x1, x2);
81 const auto t1 = wrapper::vadd(x3, x4);
82 const auto t2 = wrapper::vadd(x5, x6);
83 const auto t00 = wrapper::vadd(t0, t1);
84 const auto t01 = wrapper::vadd(t2, x7);
85
86 return wrapper::vadd(t00, t01);
87}
88
89float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
90{
91 const auto t0 = wrapper::vadd(x1, x2);
92 const auto t1 = wrapper::vadd(x3, x4);
93 const auto t2 = wrapper::vadd(x5, x6);
94 const auto t3 = wrapper::vadd(x7, x8);
95 const auto t00 = wrapper::vadd(t0, t1);
96 const auto t01 = wrapper::vadd(t2, t3);
97
98 return wrapper::vadd(t00, t01);
99}
100
101void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
102{
103 float32x2_t a = x;
104 float32x2_t b = c_mul_neon(w, y);
105
106 x = wrapper::vadd(a, b);
107 y = wrapper::vsub(a, b);
108}
109
110constexpr float sqrt3div2 = 0.866025403784438;
111void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
112{
113 float32x2_t a = x;
114 float32x2_t b = c_mul_neon(w, y);
115 float32x2_t c = c_mul_neon(w2, z);
116
117 x = wrapper::vadd(a, b);
118 x = wrapper::vadd(x, c);
119
120 const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
121 const auto v2 = c_mul_neon(float32x2_t{ 0.f, -sqrt3div2 }, wrapper::vsub(b, c));
122
123 y = z = wrapper::vsub(a, v1);
124 y = wrapper::vadd(y, v2);
125 z = wrapper::vsub(z, v2);
126}
127
128void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
129{
130 float32x2_t a = x1;
131 float32x2_t b = c_mul_neon(w, x2);
132 float32x2_t c = c_mul_neon(w2, x3);
133 float32x2_t d = c_mul_neon(w3, x4);
134
135 const auto x11 = wrapper::vadd(a, b);
136 const auto x12 = wrapper::vadd(c, d);
137 x1 = wrapper::vadd(x11, x12);
138
139 const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
140 const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
141 x2 = wrapper::vadd(x21, x22);
142
143 const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
144 const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
145 x3 = wrapper::vadd(x31, x32);
146
147 const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
148 const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
149 x4 = wrapper::vadd(x41, x42);
150}
151
152constexpr float W5_0 = 0.30901699437494f;
153constexpr float W5_1 = 0.95105651629515f;
154constexpr float W5_2 = 0.80901699437494f;
155constexpr float W5_3 = 0.58778525229247f;
156void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
157{
158 const auto a = x1;
159 const auto b = c_mul_neon(w, x2);
160 const auto c = c_mul_neon(w2, x3);
161 const auto d = c_mul_neon(w3, x4);
162 const auto e = c_mul_neon(w4, x5);
163
164 const auto b0 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, b);
165 const auto b1 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, b);
166 const auto b2 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, b);
167 const auto b3 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, b);
168
169 const auto c0 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, c);
170 const auto c1 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, c);
171 const auto c2 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, c);
172 const auto c3 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, c);
173
174 const auto d0 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, d);
175 const auto d1 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, d);
176 const auto d2 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, d);
177 const auto d3 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, d);
178
179 const auto e0 = c_mul_neon(float32x2_t{ W5_0, W5_1 }, e);
180 const auto e1 = c_mul_neon(float32x2_t{ -W5_2, W5_3 }, e);
181 const auto e2 = c_mul_neon(float32x2_t{ -W5_2, -W5_3 }, e);
182 const auto e3 = c_mul_neon(float32x2_t{ W5_0, -W5_1 }, e);
183
184 x1 = reduce_sum_5(a, b, c, d, e);
185 x2 = reduce_sum_5(a, b0, c0, d0, e0);
186 x3 = reduce_sum_5(a, b1, c1, d1, e1);
187 x4 = reduce_sum_5(a, b2, c2, d2, e2);
188 x5 = reduce_sum_5(a, b3, c3, d3, e3);
189}
190
191constexpr float W7_0 = 0.62348980185873f;
192constexpr float W7_1 = 0.78183148246802f;
193constexpr float W7_2 = 0.22252093395631f;
194constexpr float W7_3 = 0.97492791218182f;
195constexpr float W7_4 = 0.90096886790241f;
196constexpr float W7_5 = 0.43388373911755f;
197void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
198 const float32x2_t &w4,
199 const float32x2_t &w5, const float32x2_t &w6)
200{
201 const auto a = x1;
202 const auto b = c_mul_neon(w, x2);
203 const auto c = c_mul_neon(w2, x3);
204 const auto d = c_mul_neon(w3, x4);
205 const auto e = c_mul_neon(w4, x5);
206 const auto f = c_mul_neon(w5, x6);
207 const auto g = c_mul_neon(w6, x7);
208
209 const auto b0 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, b);
210 const auto b1 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, b);
211 const auto b2 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, b);
212 const auto b3 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, b);
213 const auto b4 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, b);
214 const auto b5 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, b);
215
216 const auto c0 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, c);
217 const auto c1 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, c);
218 const auto c2 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, c);
219 const auto c3 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, c);
220 const auto c4 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, c);
221 const auto c5 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, c);
222
223 const auto d0 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, d);
224 const auto d1 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, d);
225 const auto d2 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, d);
226 const auto d3 = c_mul_neon(float32x2_t{ -W7_2, +W7_3 }, d);
227 const auto d4 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, d);
228 const auto d5 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, d);
229
230 const auto e0 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, e);
231 const auto e1 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, e);
232 const auto e2 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, e);
233 const auto e3 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, e);
234 const auto e4 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, e);
235 const auto e5 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, e);
236
237 const auto f0 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, f);
238 const auto f1 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, f);
239 const auto f2 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, f);
240 const auto f3 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, f);
241 const auto f4 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, f);
242 const auto f5 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, f);
243
244 const auto g0 = c_mul_neon(float32x2_t{ W7_0, W7_1 }, g);
245 const auto g1 = c_mul_neon(float32x2_t{ -W7_2, W7_3 }, g);
246 const auto g2 = c_mul_neon(float32x2_t{ -W7_4, W7_5 }, g);
247 const auto g3 = c_mul_neon(float32x2_t{ -W7_4, -W7_5 }, g);
248 const auto g4 = c_mul_neon(float32x2_t{ -W7_2, -W7_3 }, g);
249 const auto g5 = c_mul_neon(float32x2_t{ W7_0, -W7_1 }, g);
250
251 x1 = reduce_sum_7(a, b, c, d, e, f, g);
252 x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
253 x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
254 x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
255 x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
256 x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
257 x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
258}
259
260constexpr float sqrt2div2 = 0.707106781186548;
261void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
262 const float32x2_t &w3,
263 const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
264 const float32x2_t &w7)
265{
266 const auto a = x1;
267 const auto b = c_mul_neon(w, x2);
268 const auto c = c_mul_neon(w2, x3);
269 const auto d = c_mul_neon(w3, x4);
270 const auto e = c_mul_neon(w4, x5);
271 const auto f = c_mul_neon(w5, x6);
272 const auto g = c_mul_neon(w6, x7);
273 const auto h = c_mul_neon(w7, x8);
274
275 const auto b0 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, b);
276 const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
277 const auto b2 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, b);
278 const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
279 const auto b4 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, b);
280 const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
281 const auto b6 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, b);
282
283 const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
284 const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
285 const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
286 const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
287 const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
288 const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
289 const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
290
291 const auto d0 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, d);
292 const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
293 const auto d2 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, d);
294 const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
295 const auto d4 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, d);
296 const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
297 const auto d6 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, d);
298
299 const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
300 const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
301 const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
302 const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
303 const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
304 const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
305 const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
306
307 const auto f0 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, f);
308 const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
309 const auto f2 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, f);
310 const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
311 const auto f4 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, f);
312 const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
313 const auto f6 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, f);
314
315 const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
316 const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
317 const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
318 const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
319 const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
320 const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
321 const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
322
323 const auto h0 = c_mul_neon(float32x2_t{ sqrt2div2, sqrt2div2 }, h);
324 const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
325 const auto h2 = c_mul_neon(float32x2_t{ -sqrt2div2, sqrt2div2 }, h);
326 const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
327 const auto h4 = c_mul_neon(float32x2_t{ -sqrt2div2, -sqrt2div2 }, h);
328 const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
329 const auto h6 = c_mul_neon(float32x2_t{ sqrt2div2, -sqrt2div2 }, h);
330
331 x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
332 x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
333 x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
334 x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
335 x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
336 x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
337 x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
338 x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
339}
340
341template <bool first_stage>
342void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
343{
344 unsigned int Nx2 = 2 * Nx;
345 float alpha = 2 * PI / Nx2;
346
347 float32x2_t w{ 1, 0 };
348 const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
349
350 for(unsigned int j = 0; j < Nx; j++)
351 {
352 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx2)
353 {
354 auto a = float32x2_t{ 0, 0 };
355 auto b = float32x2_t{ 0, 0 };
356
357 // Load inputs
358 if(first_stage)
359 {
360 const auto ab = wrapper::vloadq(x + k);
361 a = wrapper::vgetlow(ab);
362 b = wrapper::vgethigh(ab);
363 }
364 else
365 {
366 a = wrapper::vload(x + k);
367 b = wrapper::vload(x + k + 2 * Nx);
368 }
369
370 // Base-case prime transform
371 fft_2(a, b, w);
372
373 // Write outputs
374 if(first_stage)
375 {
376 wrapper::vstore(X + k, wrapper::vcombine(a, b));
377 }
378 else
379 {
380 wrapper::vstore(X + k, a);
381 wrapper::vstore(X + k + 2 * Nx, b);
382 }
383 }
384
385 w = c_mul_neon(w, w_m);
386 }
387}
388
389template <bool first_stage>
390void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
391{
392 const unsigned int Nx3 = 3 * Nx;
393 const float alpha = 2 * PI / float(Nx3);
394 float32x2_t w{ 1, 0 };
395 const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
396
397 for(unsigned int j = 0; j < Nx; j++)
398 {
399 const auto w2 = c_mul_neon(w, w);
400
401 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx3)
402 {
403 // Load inputs
404 float32x2_t a = { 0, 0 };
405 float32x2_t b = { 0, 0 };
406 float32x2_t c = { 0, 0 };
407 if(first_stage)
408 {
409 const auto ab = wrapper::vloadq(x + k);
410 a = wrapper::vgetlow(ab);
411 b = wrapper::vgethigh(ab);
412 }
413 else
414 {
415 a = wrapper::vload(x + k);
416 b = wrapper::vload(x + k + 2 * Nx);
417 }
418 c = wrapper::vload(x + k + 4 * Nx);
419
420 // Base-case prime transform
421 fft_3(a, b, c, w, w2);
422
423 if(first_stage)
424 {
425 wrapper::vstore(X + k, wrapper::vcombine(a, b));
426 }
427 else
428 {
429 wrapper::vstore(X + k, a);
430 wrapper::vstore(X + k + 2 * Nx, b);
431 }
432 wrapper::vstore(X + k + 4 * Nx, c);
433 }
434 w = c_mul_neon(w, w_m);
435 }
436}
437
438template <bool first_stage>
439void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
440{
441 unsigned int Nx4 = 4 * Nx;
442 const float alpha = 2 * PI / float(Nx4);
443
444 float32x2_t w{ 1, 0 };
445 float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
446
447 for(unsigned int j = 0; j < Nx; j++)
448 {
449 const auto w2 = c_mul_neon(w, w);
450 const auto w3 = c_mul_neon(w2, w);
451
452 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx4)
453 {
454 float32x2_t a = { 0, 0 };
455 float32x2_t b = { 0, 0 };
456 float32x2_t c = { 0, 0 };
457 float32x2_t d = { 0, 0 };
458 if(first_stage)
459 {
460 const auto ab = wrapper::vloadq(x + k);
461 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
462 a = wrapper::vgetlow(ab);
463 b = wrapper::vgethigh(ab);
464 c = wrapper::vgetlow(cd);
465 d = wrapper::vgethigh(cd);
466 }
467 else
468 {
469 // Load inputs
470 a = wrapper::vload(x + k);
471 b = wrapper::vload(x + k + 2 * Nx);
472 c = wrapper::vload(x + k + 4 * Nx);
473 d = wrapper::vload(x + k + 6 * Nx);
474 }
475
476 // Base-case prime transform
477 fft_4(a, b, c, d, w, w2, w3);
478
479 if(first_stage)
480 {
481 wrapper::vstore(X + k, wrapper::vcombine(a, b));
482 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
483 }
484 else
485 {
486 wrapper::vstore(X + k, a);
487 wrapper::vstore(X + k + 2 * Nx, b);
488 wrapper::vstore(X + k + 4 * Nx, c);
489 wrapper::vstore(X + k + 6 * Nx, d);
490 }
491 }
492
493 w = c_mul_neon(w, w_m);
494 }
495}
496
497template <bool first_stage>
498void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
499{
500 unsigned int Nx5 = 5 * Nx;
501 const float alpha = 2 * PI / float(Nx5);
502
503 float32x2_t w{ 1, 0 };
504 float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
505
506 for(unsigned int j = 0; j < Nx; j++)
507 {
508 const float32x2_t w2 = c_mul_neon(w, w);
509 const float32x2_t w3 = c_mul_neon(w2, w);
510 const float32x2_t w4 = c_mul_neon(w3, w);
511
512 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx5)
513 {
514 float32x2_t a = { 0, 0 };
515 float32x2_t b = { 0, 0 };
516 float32x2_t c = { 0, 0 };
517 float32x2_t d = { 0, 0 };
518 float32x2_t e = { 0, 0 };
519
520 // Load inputs
521 if(first_stage)
522 {
523 const auto ab = wrapper::vloadq(x + k);
524 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
525
526 a = wrapper::vgetlow(ab);
527 b = wrapper::vgethigh(ab);
528 c = wrapper::vgetlow(cd);
529 d = wrapper::vgethigh(cd);
530 }
531 else
532 {
533 a = wrapper::vload(x + k);
534 b = wrapper::vload(x + k + 2 * Nx);
535 c = wrapper::vload(x + k + 4 * Nx);
536 d = wrapper::vload(x + k + 6 * Nx);
537 }
538 e = wrapper::vload(x + k + 8 * Nx);
539
540 // Base-case prime transform
541 fft_5(a, b, c, d, e, w, w2, w3, w4);
542
543 // Store outputs
544 if(first_stage)
545 {
546 wrapper::vstore(X + k, wrapper::vcombine(a, b));
547 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
548 }
549 else
550 {
551 wrapper::vstore(X + k, a);
552 wrapper::vstore(X + k + 2 * Nx, b);
553 wrapper::vstore(X + k + 4 * Nx, c);
554 wrapper::vstore(X + k + 6 * Nx, d);
555 }
556 wrapper::vstore(X + k + 8 * Nx, e);
557 }
558
559 w = c_mul_neon(w, w_m);
560 }
561}
562
563template <bool first_stage>
564void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
565{
566 unsigned int Nx7 = 7 * Nx;
567 const float alpha = 2 * PI / float(Nx7);
568
569 float32x2_t w{ 1, 0 };
570 float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
571
572 for(unsigned int j = 0; j < Nx; j++)
573 {
574 const float32x2_t w2 = c_mul_neon(w, w);
575 const float32x2_t w3 = c_mul_neon(w2, w);
576 const float32x2_t w4 = c_mul_neon(w3, w);
577 const float32x2_t w5 = c_mul_neon(w4, w);
578 const float32x2_t w6 = c_mul_neon(w5, w);
579
580 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx7)
581 {
582 float32x2_t a = { 0, 0 };
583 float32x2_t b = { 0, 0 };
584 float32x2_t c = { 0, 0 };
585 float32x2_t d = { 0, 0 };
586 float32x2_t e = { 0, 0 };
587 float32x2_t f = { 0, 0 };
588 float32x2_t g = { 0, 0 };
589
590 // Load inputs
591 if(first_stage)
592 {
593 const auto ab = wrapper::vloadq(x + k);
594 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
595 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
596
597 a = wrapper::vgetlow(ab);
598 b = wrapper::vgethigh(ab);
599 c = wrapper::vgetlow(cd);
600 d = wrapper::vgethigh(cd);
601 e = wrapper::vgetlow(ef);
602 f = wrapper::vgethigh(ef);
603 }
604 else
605 {
606 a = wrapper::vload(x + k);
607 b = wrapper::vload(x + k + 2 * Nx);
608 c = wrapper::vload(x + k + 4 * Nx);
609 d = wrapper::vload(x + k + 6 * Nx);
610 e = wrapper::vload(x + k + 8 * Nx);
611 f = wrapper::vload(x + k + 10 * Nx);
612 }
613 g = wrapper::vload(x + k + 12 * Nx);
614
615 // Base-case prime transform
616 fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
617
618 if(first_stage)
619 {
620 wrapper::vstore(X + k, wrapper::vcombine(a, b));
621 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
622 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
623 }
624 else
625 {
626 wrapper::vstore(X + k, a);
627 wrapper::vstore(X + k + 2 * Nx, b);
628 wrapper::vstore(X + k + 4 * Nx, c);
629 wrapper::vstore(X + k + 6 * Nx, d);
630 wrapper::vstore(X + k + 8 * Nx, e);
631 wrapper::vstore(X + k + 10 * Nx, f);
632 }
633 wrapper::vstore(X + k + 12 * Nx, g);
634 }
635
636 w = c_mul_neon(w, w_m);
637 }
638}
639
640template <bool first_stage>
641void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int N)
642{
643 unsigned int Nx8 = 8 * Nx;
644 const float alpha = 2 * PI / float(Nx8);
645
646 float32x2_t w{ 1, 0 };
647 const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
648
649 for(unsigned int j = 0; j < Nx; j++)
650 {
651 const float32x2_t w2 = c_mul_neon(w, w);
652 const float32x2_t w3 = c_mul_neon(w2, w);
653 const float32x2_t w4 = c_mul_neon(w3, w);
654 const float32x2_t w5 = c_mul_neon(w4, w);
655 const float32x2_t w6 = c_mul_neon(w5, w);
656 const float32x2_t w7 = c_mul_neon(w6, w);
657
658 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * Nx8)
659 {
660 // Load inputs
661 float32x2_t a = { 0, 0 };
662 float32x2_t b = { 0, 0 };
663 float32x2_t c = { 0, 0 };
664 float32x2_t d = { 0, 0 };
665 float32x2_t e = { 0, 0 };
666 float32x2_t f = { 0, 0 };
667 float32x2_t g = { 0, 0 };
668 float32x2_t h = { 0, 0 };
669
670 // Base-case prime transform
671 if(first_stage)
672 {
673 const auto ab = wrapper::vloadq(x + k);
674 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
675 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
676 const auto gh = wrapper::vloadq(x + k + 12 * Nx);
677
678 a = wrapper::vgetlow(ab);
679 b = wrapper::vgethigh(ab);
680 c = wrapper::vgetlow(cd);
681 d = wrapper::vgethigh(cd);
682 e = wrapper::vgetlow(ef);
683 f = wrapper::vgethigh(ef);
684 g = wrapper::vgetlow(gh);
685 h = wrapper::vgethigh(gh);
686 }
687 else
688 {
689 a = wrapper::vload(x + k);
690 b = wrapper::vload(x + k + 2 * Nx);
691 c = wrapper::vload(x + k + 4 * Nx);
692 d = wrapper::vload(x + k + 6 * Nx);
693 e = wrapper::vload(x + k + 8 * Nx);
694 f = wrapper::vload(x + k + 10 * Nx);
695 g = wrapper::vload(x + k + 12 * Nx);
696 h = wrapper::vload(x + k + 14 * Nx);
697 }
698
699 // Apply twiddle factors
700 fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
701
702 // Store outputs
703 if(first_stage)
704 {
705 wrapper::vstore(X + k, wrapper::vcombine(a, b));
706 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
707 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
708 wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h));
709 }
710 else
711 {
712 wrapper::vstore(X + k, a);
713 wrapper::vstore(X + k + 2 * Nx, b);
714 wrapper::vstore(X + k + 4 * Nx, c);
715 wrapper::vstore(X + k + 6 * Nx, d);
716 wrapper::vstore(X + k + 8 * Nx, e);
717 wrapper::vstore(X + k + 10 * Nx, f);
718 wrapper::vstore(X + k + 12 * Nx, g);
719 wrapper::vstore(X + k + 14 * Nx, h);
720 }
721 }
722
723 w = c_mul_neon(w, w_m);
724 }
725}
726
727Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
728{
729 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
730 ARM_COMPUTE_RETURN_ERROR_ON(config.axis != 0);
731 ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
732
733 // Checks performed when output is configured
734 if((output != nullptr) && (output->total_size() != 0))
735 {
736 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
737 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
738 }
739
740 return Status{};
741}
742
743std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
744{
745 if(output != nullptr)
746 {
747 auto_init_if_empty(*output, *input);
748 }
749
750 Window win = calculate_max_window(*input, Steps(config.radix));
751 if(output != nullptr)
752 {
753 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
754 }
755
756 return std::make_pair(Status{}, win);
757}
758} // namespace
759
760NEFFTRadixStageKernel::NEFFTRadixStageKernel()
761 : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _func()
762{
763}
764
765template <bool first_stage>
766void NEFFTRadixStageKernel::set_radix_stage_fun(unsigned int radix)
767{
768 switch(radix)
769 {
770 case 2:
771 _func = &fft_radix_2_axes_0<first_stage>;
772 break;
773 case 3:
774 _func = &fft_radix_3_axes_0<first_stage>;
775 break;
776 case 4:
777 _func = &fft_radix_4_axes_0<first_stage>;
778 break;
779 case 5:
780 _func = &fft_radix_5_axes_0<first_stage>;
781 break;
782 case 7:
783 _func = &fft_radix_7_axes_0<first_stage>;
784 break;
785 case 8:
786 _func = &fft_radix_8_axes_0<first_stage>;
787 break;
788 default:
789 ARM_COMPUTE_ERROR("Radix not supported");
790 }
791}
792
793void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
794{
795 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
796
797 // Output auto inizialitation if not yet initialized
798 if(output != nullptr)
799 {
800 auto_init_if_empty(*output->info(), *input->info()->clone());
801 }
802
803 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
804
805 _input = input;
806 _output = output;
807 _run_in_place = (output == nullptr) || (output == input);
808 _Nx = config.Nx;
809
810 if(config.is_first_stage)
811 {
812 set_radix_stage_fun<true>(config.radix);
813 }
814 else
815 {
816 set_radix_stage_fun<false>(config.radix);
817 }
818
819 // Configure kernel window
820 auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config);
821 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
822 INEKernel::configure(win_config.second);
823}
824
825Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
826{
827 const bool run_in_place = (output == nullptr) || (output == input);
828 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
829 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
830 (run_in_place) ? nullptr : output->clone().get(),
831 config)
832 .first);
833
834 return Status{};
835}
836
837std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
838{
839 return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
840}
841
842void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
843{
844 ARM_COMPUTE_UNUSED(info);
845 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
846 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
847
848 Window input_window = window;
849 input_window.set(Window::DimX, 0);
850
851 unsigned int N = _input->info()->dimension(0);
852
853 Iterator in(_input, input_window);
854 Iterator out(_run_in_place ? _input : _output, input_window);
855
856 execute_window_loop(input_window, [&](const Coordinates &)
857 {
858 _func(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, N);
859 },
860 in, out);
861
862 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
863 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
864}
865} // namespace arm_compute