blob: cb1391ab4e31c5e2371f7007335ad1ea84d6e4ae [file] [log] [blame]
giuros0114c4e0f2019-03-26 17:44:40 +00001/*
Georgios Pinitasddb93bb2020-10-02 16:38:59 +01002 * Copyright (c) 2019-2020 Arm Limited.
giuros0114c4e0f2019-03-26 17:44:40 +00003 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
Michalis Spyrouebcebf12020-10-21 00:04:14 +010024#include "src/core/NEON/kernels/NEFFTRadixStageKernel.h"
giuros0114c4e0f2019-03-26 17:44:40 +000025
26#include "arm_compute/core/ITensor.h"
giuros0114c4e0f2019-03-26 17:44:40 +000027#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Types.h"
29#include "arm_compute/core/Utils.h"
30#include "arm_compute/core/Window.h"
Sang-Hoon Park68dd25f2020-10-19 16:00:11 +010031#include "src/core/NEON/wrapper/traits.h"
32#include "src/core/NEON/wrapper/wrapper.h"
33#include "src/core/helpers/AutoConfiguration.h"
34#include "src/core/helpers/WindowHelpers.h"
giuros0114c4e0f2019-03-26 17:44:40 +000035
36#include <arm_neon.h>
37#include <cmath>
38#include <complex>
giuros0105fb4482019-03-26 17:44:40 +000039#include <map>
40
giuros0114c4e0f2019-03-26 17:44:40 +000041namespace arm_compute
42{
43namespace
44{
// Pi in single precision (M_PI is provided by <cmath>)
constexpr float kPi = static_cast<float>(M_PI);

// fft_3 kernel: sqrt(3) / 2, i.e. sin(2 * pi / 3)
constexpr float kSqrt3Div2 = 0.866025403784438;

// fft_5 kernel: magnitudes of the real/imaginary parts of the 5th roots of unity
constexpr float kW5_0 = 0.30901699437494f; // cos(2 * pi / 5)
constexpr float kW5_1 = 0.95105651629515f; // sin(2 * pi / 5)
constexpr float kW5_2 = 0.80901699437494f; // |cos(4 * pi / 5)|
constexpr float kW5_3 = 0.58778525229247f; // sin(4 * pi / 5)

// fft_7 kernel: magnitudes of the real/imaginary parts of the 7th roots of unity
constexpr float kW7_0 = 0.62348980185873f; // cos(2 * pi / 7)
constexpr float kW7_1 = 0.78183148246802f; // sin(2 * pi / 7)
constexpr float kW7_2 = 0.22252093395631f; // |cos(4 * pi / 7)|
constexpr float kW7_3 = 0.97492791218182f; // sin(4 * pi / 7)
constexpr float kW7_4 = 0.90096886790241f; // |cos(6 * pi / 7)|
constexpr float kW7_5 = 0.43388373911755f; // sin(6 * pi / 7)

// fft_8 kernel: sqrt(2) / 2, i.e. cos(pi / 4) == sin(pi / 4)
constexpr float kSqrt2Div2 = 0.707106781186548;
giuros0114c4e0f2019-03-26 17:44:40 +000067
68float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
69{
giuros0105fb4482019-03-26 17:44:40 +000070 using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
giuros0114c4e0f2019-03-26 17:44:40 +000071
giuros0105fb4482019-03-26 17:44:40 +000072 const float32x2_t mask = { -1.0, 1.0 };
73 const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
74 const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
giuros0114c4e0f2019-03-26 17:44:40 +000075
giuros0105fb4482019-03-26 17:44:40 +000076 float32x2_t res = wrapper::vmul(tmp0, b);
giuros0114c4e0f2019-03-26 17:44:40 +000077
giuros0105fb4482019-03-26 17:44:40 +000078 b = wrapper::vrev64(b);
79 b = wrapper::vmul(b, mask);
80 res = wrapper::vmla(res, tmp1, b);
81
82 return res;
giuros0114c4e0f2019-03-26 17:44:40 +000083}
84
85float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
86{
87 const float a_r = wrapper::vgetlane(a, 0);
88 const float a_i = wrapper::vgetlane(a, 1);
89
90 const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
91 return out;
92}
93
94float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
95{
96 const auto t0 = wrapper::vadd(a, b);
97 const auto t1 = wrapper::vadd(c, d);
98 const auto t2 = wrapper::vadd(t0, t1);
99 return wrapper::vadd(t2, e);
100}
101
102float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
103{
104 const auto t0 = wrapper::vadd(x1, x2);
105 const auto t1 = wrapper::vadd(x3, x4);
106 const auto t2 = wrapper::vadd(x5, x6);
107 const auto t00 = wrapper::vadd(t0, t1);
108 const auto t01 = wrapper::vadd(t2, x7);
109
110 return wrapper::vadd(t00, t01);
111}
112
113float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
114{
115 const auto t0 = wrapper::vadd(x1, x2);
116 const auto t1 = wrapper::vadd(x3, x4);
117 const auto t2 = wrapper::vadd(x5, x6);
118 const auto t3 = wrapper::vadd(x7, x8);
119 const auto t00 = wrapper::vadd(t0, t1);
120 const auto t01 = wrapper::vadd(t2, t3);
121
122 return wrapper::vadd(t00, t01);
123}
124
125void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
126{
127 float32x2_t a = x;
128 float32x2_t b = c_mul_neon(w, y);
129
130 x = wrapper::vadd(a, b);
131 y = wrapper::vsub(a, b);
132}
133
giuros0114c4e0f2019-03-26 17:44:40 +0000134void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
135{
136 float32x2_t a = x;
137 float32x2_t b = c_mul_neon(w, y);
138 float32x2_t c = c_mul_neon(w2, z);
139
140 x = wrapper::vadd(a, b);
141 x = wrapper::vadd(x, c);
142
143 const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
giuros0105fb4482019-03-26 17:44:40 +0000144 const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
giuros0114c4e0f2019-03-26 17:44:40 +0000145
146 y = z = wrapper::vsub(a, v1);
147 y = wrapper::vadd(y, v2);
148 z = wrapper::vsub(z, v2);
149}
150
151void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
152{
153 float32x2_t a = x1;
154 float32x2_t b = c_mul_neon(w, x2);
155 float32x2_t c = c_mul_neon(w2, x3);
156 float32x2_t d = c_mul_neon(w3, x4);
157
158 const auto x11 = wrapper::vadd(a, b);
159 const auto x12 = wrapper::vadd(c, d);
160 x1 = wrapper::vadd(x11, x12);
161
162 const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
163 const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
164 x2 = wrapper::vadd(x21, x22);
165
166 const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
167 const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
168 x3 = wrapper::vadd(x31, x32);
169
170 const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
171 const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
172 x4 = wrapper::vadd(x41, x42);
173}
174
giuros0114c4e0f2019-03-26 17:44:40 +0000175void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
176{
177 const auto a = x1;
178 const auto b = c_mul_neon(w, x2);
179 const auto c = c_mul_neon(w2, x3);
180 const auto d = c_mul_neon(w3, x4);
181 const auto e = c_mul_neon(w4, x5);
182
giuros0105fb4482019-03-26 17:44:40 +0000183 const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
184 const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
185 const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
186 const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000187
giuros0105fb4482019-03-26 17:44:40 +0000188 const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
189 const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
190 const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
191 const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);
giuros0114c4e0f2019-03-26 17:44:40 +0000192
giuros0105fb4482019-03-26 17:44:40 +0000193 const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
194 const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
195 const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
196 const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000197
giuros0105fb4482019-03-26 17:44:40 +0000198 const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
199 const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
200 const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
201 const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);
giuros0114c4e0f2019-03-26 17:44:40 +0000202
203 x1 = reduce_sum_5(a, b, c, d, e);
204 x2 = reduce_sum_5(a, b0, c0, d0, e0);
205 x3 = reduce_sum_5(a, b1, c1, d1, e1);
206 x4 = reduce_sum_5(a, b2, c2, d2, e2);
207 x5 = reduce_sum_5(a, b3, c3, d3, e3);
208}
209
giuros0114c4e0f2019-03-26 17:44:40 +0000210void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
211 const float32x2_t &w4,
212 const float32x2_t &w5, const float32x2_t &w6)
213{
214 const auto a = x1;
215 const auto b = c_mul_neon(w, x2);
216 const auto c = c_mul_neon(w2, x3);
217 const auto d = c_mul_neon(w3, x4);
218 const auto e = c_mul_neon(w4, x5);
219 const auto f = c_mul_neon(w5, x6);
220 const auto g = c_mul_neon(w6, x7);
221
giuros0105fb4482019-03-26 17:44:40 +0000222 const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
223 const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
224 const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
225 const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
226 const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
227 const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000228
giuros0105fb4482019-03-26 17:44:40 +0000229 const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
230 const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
231 const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
232 const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
233 const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
234 const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);
giuros0114c4e0f2019-03-26 17:44:40 +0000235
giuros0105fb4482019-03-26 17:44:40 +0000236 const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
237 const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
238 const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
239 const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
240 const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
241 const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000242
giuros0105fb4482019-03-26 17:44:40 +0000243 const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
244 const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
245 const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
246 const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
247 const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
248 const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);
giuros0114c4e0f2019-03-26 17:44:40 +0000249
giuros0105fb4482019-03-26 17:44:40 +0000250 const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
251 const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
252 const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
253 const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
254 const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
255 const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000256
giuros0105fb4482019-03-26 17:44:40 +0000257 const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
258 const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
259 const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
260 const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
261 const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
262 const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);
giuros0114c4e0f2019-03-26 17:44:40 +0000263
264 x1 = reduce_sum_7(a, b, c, d, e, f, g);
265 x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
266 x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
267 x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
268 x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
269 x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
270 x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
271}
272
giuros0114c4e0f2019-03-26 17:44:40 +0000273void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
274 const float32x2_t &w3,
275 const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
276 const float32x2_t &w7)
277{
278 const auto a = x1;
279 const auto b = c_mul_neon(w, x2);
280 const auto c = c_mul_neon(w2, x3);
281 const auto d = c_mul_neon(w3, x4);
282 const auto e = c_mul_neon(w4, x5);
283 const auto f = c_mul_neon(w5, x6);
284 const auto g = c_mul_neon(w6, x7);
285 const auto h = c_mul_neon(w7, x8);
286
giuros0105fb4482019-03-26 17:44:40 +0000287 const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000288 const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000289 const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000290 const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000291 const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000292 const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000293 const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000294
295 const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
296 const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
297 const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
298 const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
299 const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
300 const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
301 const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
302
giuros0105fb4482019-03-26 17:44:40 +0000303 const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000304 const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000305 const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000306 const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000307 const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000308 const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000309 const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000310
311 const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
312 const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
313 const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
314 const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
315 const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
316 const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
317 const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
318
giuros0105fb4482019-03-26 17:44:40 +0000319 const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000320 const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000321 const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000322 const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000323 const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000324 const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000325 const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000326
327 const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
328 const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
329 const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
330 const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
331 const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
332 const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
333 const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
334
giuros0105fb4482019-03-26 17:44:40 +0000335 const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000336 const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000337 const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000338 const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000339 const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000340 const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000341 const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000342
343 x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
344 x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
345 x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
346 x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
347 x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
348 x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
349 x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
350 x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
351}
352
353template <bool first_stage>
giuros0105fb4482019-03-26 17:44:40 +0000354void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000355{
giuros0105fb4482019-03-26 17:44:40 +0000356 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000357 for(unsigned int j = 0; j < Nx; j++)
358 {
giuros0105fb4482019-03-26 17:44:40 +0000359 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000360 {
361 auto a = float32x2_t{ 0, 0 };
362 auto b = float32x2_t{ 0, 0 };
363
364 // Load inputs
365 if(first_stage)
366 {
367 const auto ab = wrapper::vloadq(x + k);
368 a = wrapper::vgetlow(ab);
369 b = wrapper::vgethigh(ab);
370 }
371 else
372 {
373 a = wrapper::vload(x + k);
374 b = wrapper::vload(x + k + 2 * Nx);
375 }
376
377 // Base-case prime transform
378 fft_2(a, b, w);
379
380 // Write outputs
381 if(first_stage)
382 {
383 wrapper::vstore(X + k, wrapper::vcombine(a, b));
384 }
385 else
386 {
387 wrapper::vstore(X + k, a);
388 wrapper::vstore(X + k + 2 * Nx, b);
389 }
390 }
391
392 w = c_mul_neon(w, w_m);
393 }
394}
395
giuros0105fb4482019-03-26 17:44:40 +0000396void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000397{
giuros0105fb4482019-03-26 17:44:40 +0000398 float32x2_t w{ 1.0f, 0.0f };
399 for(unsigned int j = 0; j < Nx; j++)
400 {
401 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
402 {
403 // Load inputs
404 float32x2_t a = wrapper::vload(x + M * k);
405 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000406
giuros0105fb4482019-03-26 17:44:40 +0000407 // Base-case prime transform
408 fft_2(a, b, w);
409
410 // Write outputs
411 wrapper::vstore(X + M * k, a);
412 wrapper::vstore(X + M * (k + 2 * Nx), b);
413 }
414
415 w = c_mul_neon(w, w_m);
416 }
417}
418
419template <bool first_stage>
420void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
421{
422 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000423 for(unsigned int j = 0; j < Nx; j++)
424 {
425 const auto w2 = c_mul_neon(w, w);
426
giuros0105fb4482019-03-26 17:44:40 +0000427 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000428 {
429 // Load inputs
430 float32x2_t a = { 0, 0 };
431 float32x2_t b = { 0, 0 };
432 float32x2_t c = { 0, 0 };
433 if(first_stage)
434 {
435 const auto ab = wrapper::vloadq(x + k);
436 a = wrapper::vgetlow(ab);
437 b = wrapper::vgethigh(ab);
438 }
439 else
440 {
441 a = wrapper::vload(x + k);
442 b = wrapper::vload(x + k + 2 * Nx);
443 }
444 c = wrapper::vload(x + k + 4 * Nx);
445
446 // Base-case prime transform
447 fft_3(a, b, c, w, w2);
448
449 if(first_stage)
450 {
451 wrapper::vstore(X + k, wrapper::vcombine(a, b));
452 }
453 else
454 {
455 wrapper::vstore(X + k, a);
456 wrapper::vstore(X + k + 2 * Nx, b);
457 }
458 wrapper::vstore(X + k + 4 * Nx, c);
459 }
460 w = c_mul_neon(w, w_m);
461 }
462}
463
giuros0105fb4482019-03-26 17:44:40 +0000464void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000465{
giuros0105fb4482019-03-26 17:44:40 +0000466 float32x2_t w{ 1.0f, 0.0f };
467 for(unsigned int j = 0; j < Nx; j++)
468 {
469 const auto w2 = c_mul_neon(w, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000470
giuros0105fb4482019-03-26 17:44:40 +0000471 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
472 {
473 // Load inputs
474 float32x2_t a = wrapper::vload(x + M * k);
475 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
476 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000477
giuros0105fb4482019-03-26 17:44:40 +0000478 // Base-case prime transform
479 fft_3(a, b, c, w, w2);
480
481 // Store the output
482 wrapper::vstore(X + M * k, a);
483 wrapper::vstore(X + M * (k + 2 * Nx), b);
484 wrapper::vstore(X + M * (k + 4 * Nx), c);
485 }
486 w = c_mul_neon(w, w_m);
487 }
488}
489
490template <bool first_stage>
491void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
492{
493 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000494 for(unsigned int j = 0; j < Nx; j++)
495 {
496 const auto w2 = c_mul_neon(w, w);
497 const auto w3 = c_mul_neon(w2, w);
498
giuros0105fb4482019-03-26 17:44:40 +0000499 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000500 {
501 float32x2_t a = { 0, 0 };
502 float32x2_t b = { 0, 0 };
503 float32x2_t c = { 0, 0 };
504 float32x2_t d = { 0, 0 };
505 if(first_stage)
506 {
507 const auto ab = wrapper::vloadq(x + k);
508 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
509 a = wrapper::vgetlow(ab);
510 b = wrapper::vgethigh(ab);
511 c = wrapper::vgetlow(cd);
512 d = wrapper::vgethigh(cd);
513 }
514 else
515 {
516 // Load inputs
517 a = wrapper::vload(x + k);
518 b = wrapper::vload(x + k + 2 * Nx);
519 c = wrapper::vload(x + k + 4 * Nx);
520 d = wrapper::vload(x + k + 6 * Nx);
521 }
522
523 // Base-case prime transform
524 fft_4(a, b, c, d, w, w2, w3);
525
526 if(first_stage)
527 {
528 wrapper::vstore(X + k, wrapper::vcombine(a, b));
529 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
530 }
531 else
532 {
533 wrapper::vstore(X + k, a);
534 wrapper::vstore(X + k + 2 * Nx, b);
535 wrapper::vstore(X + k + 4 * Nx, c);
536 wrapper::vstore(X + k + 6 * Nx, d);
537 }
538 }
539
540 w = c_mul_neon(w, w_m);
541 }
542}
543
giuros0105fb4482019-03-26 17:44:40 +0000544void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000545{
giuros0105fb4482019-03-26 17:44:40 +0000546 float32x2_t w{ 1.0f, 0.0f };
547 for(unsigned int j = 0; j < Nx; j++)
548 {
549 const auto w2 = c_mul_neon(w, w);
550 const auto w3 = c_mul_neon(w2, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000551
giuros0105fb4482019-03-26 17:44:40 +0000552 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
553 {
554 // Load inputs
555 float32x2_t a = wrapper::vload(x + M * k);
556 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
557 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
558 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000559
giuros0105fb4482019-03-26 17:44:40 +0000560 // Base-case prime transform
561 fft_4(a, b, c, d, w, w2, w3);
562
563 wrapper::vstore(X + M * k, a);
564 wrapper::vstore(X + M * (k + 2 * Nx), b);
565 wrapper::vstore(X + M * (k + 4 * Nx), c);
566 wrapper::vstore(X + M * (k + 6 * Nx), d);
567 }
568
569 w = c_mul_neon(w, w_m);
570 }
571}
572
573template <bool first_stage>
574void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
575{
576 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000577 for(unsigned int j = 0; j < Nx; j++)
578 {
579 const float32x2_t w2 = c_mul_neon(w, w);
580 const float32x2_t w3 = c_mul_neon(w2, w);
581 const float32x2_t w4 = c_mul_neon(w3, w);
582
giuros0105fb4482019-03-26 17:44:40 +0000583 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000584 {
585 float32x2_t a = { 0, 0 };
586 float32x2_t b = { 0, 0 };
587 float32x2_t c = { 0, 0 };
588 float32x2_t d = { 0, 0 };
589 float32x2_t e = { 0, 0 };
590
591 // Load inputs
592 if(first_stage)
593 {
594 const auto ab = wrapper::vloadq(x + k);
595 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
596
597 a = wrapper::vgetlow(ab);
598 b = wrapper::vgethigh(ab);
599 c = wrapper::vgetlow(cd);
600 d = wrapper::vgethigh(cd);
601 }
602 else
603 {
604 a = wrapper::vload(x + k);
605 b = wrapper::vload(x + k + 2 * Nx);
606 c = wrapper::vload(x + k + 4 * Nx);
607 d = wrapper::vload(x + k + 6 * Nx);
608 }
609 e = wrapper::vload(x + k + 8 * Nx);
610
611 // Base-case prime transform
612 fft_5(a, b, c, d, e, w, w2, w3, w4);
613
614 // Store outputs
615 if(first_stage)
616 {
617 wrapper::vstore(X + k, wrapper::vcombine(a, b));
618 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
619 }
620 else
621 {
622 wrapper::vstore(X + k, a);
623 wrapper::vstore(X + k + 2 * Nx, b);
624 wrapper::vstore(X + k + 4 * Nx, c);
625 wrapper::vstore(X + k + 6 * Nx, d);
626 }
627 wrapper::vstore(X + k + 8 * Nx, e);
628 }
629
630 w = c_mul_neon(w, w_m);
631 }
632}
633
giuros0105fb4482019-03-26 17:44:40 +0000634void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000635{
giuros0105fb4482019-03-26 17:44:40 +0000636 float32x2_t w{ 1.0f, 0.0f };
637 for(unsigned int j = 0; j < Nx; j++)
638 {
639 const float32x2_t w2 = c_mul_neon(w, w);
640 const float32x2_t w3 = c_mul_neon(w2, w);
641 const float32x2_t w4 = c_mul_neon(w3, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000642
giuros0105fb4482019-03-26 17:44:40 +0000643 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
644 {
645 // Load inputs
646 float32x2_t a = wrapper::vload(x + M * k);
647 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
648 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
649 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
650 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000651
giuros0105fb4482019-03-26 17:44:40 +0000652 // Base-case prime transform
653 fft_5(a, b, c, d, e, w, w2, w3, w4);
654
655 // Store outputs
656 wrapper::vstore(X + M * k, a);
657 wrapper::vstore(X + M * (k + 2 * Nx), b);
658 wrapper::vstore(X + M * (k + 4 * Nx), c);
659 wrapper::vstore(X + M * (k + 6 * Nx), d);
660 wrapper::vstore(X + M * (k + 8 * Nx), e);
661 }
662
663 w = c_mul_neon(w, w_m);
664 }
665}
666
667template <bool first_stage>
668void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
669{
670 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000671 for(unsigned int j = 0; j < Nx; j++)
672 {
673 const float32x2_t w2 = c_mul_neon(w, w);
674 const float32x2_t w3 = c_mul_neon(w2, w);
675 const float32x2_t w4 = c_mul_neon(w3, w);
676 const float32x2_t w5 = c_mul_neon(w4, w);
677 const float32x2_t w6 = c_mul_neon(w5, w);
678
giuros0105fb4482019-03-26 17:44:40 +0000679 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000680 {
681 float32x2_t a = { 0, 0 };
682 float32x2_t b = { 0, 0 };
683 float32x2_t c = { 0, 0 };
684 float32x2_t d = { 0, 0 };
685 float32x2_t e = { 0, 0 };
686 float32x2_t f = { 0, 0 };
687 float32x2_t g = { 0, 0 };
688
689 // Load inputs
690 if(first_stage)
691 {
692 const auto ab = wrapper::vloadq(x + k);
693 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
694 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
695
696 a = wrapper::vgetlow(ab);
697 b = wrapper::vgethigh(ab);
698 c = wrapper::vgetlow(cd);
699 d = wrapper::vgethigh(cd);
700 e = wrapper::vgetlow(ef);
701 f = wrapper::vgethigh(ef);
702 }
703 else
704 {
705 a = wrapper::vload(x + k);
706 b = wrapper::vload(x + k + 2 * Nx);
707 c = wrapper::vload(x + k + 4 * Nx);
708 d = wrapper::vload(x + k + 6 * Nx);
709 e = wrapper::vload(x + k + 8 * Nx);
710 f = wrapper::vload(x + k + 10 * Nx);
711 }
712 g = wrapper::vload(x + k + 12 * Nx);
713
714 // Base-case prime transform
715 fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
716
717 if(first_stage)
718 {
719 wrapper::vstore(X + k, wrapper::vcombine(a, b));
720 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
721 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
722 }
723 else
724 {
725 wrapper::vstore(X + k, a);
726 wrapper::vstore(X + k + 2 * Nx, b);
727 wrapper::vstore(X + k + 4 * Nx, c);
728 wrapper::vstore(X + k + 6 * Nx, d);
729 wrapper::vstore(X + k + 8 * Nx, e);
730 wrapper::vstore(X + k + 10 * Nx, f);
731 }
732 wrapper::vstore(X + k + 12 * Nx, g);
733 }
734
735 w = c_mul_neon(w, w_m);
736 }
737}
738
giuros0105fb4482019-03-26 17:44:40 +0000739void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000740{
giuros0105fb4482019-03-26 17:44:40 +0000741 float32x2_t w{ 1.0f, 0.0f };
742 for(unsigned int j = 0; j < Nx; j++)
743 {
744 const float32x2_t w2 = c_mul_neon(w, w);
745 const float32x2_t w3 = c_mul_neon(w2, w);
746 const float32x2_t w4 = c_mul_neon(w3, w);
747 const float32x2_t w5 = c_mul_neon(w4, w);
748 const float32x2_t w6 = c_mul_neon(w5, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000749
giuros0105fb4482019-03-26 17:44:40 +0000750 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
751 {
752 // Load inputs
753 float32x2_t a = wrapper::vload(x + M * k);
754 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
755 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
756 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
757 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
758 float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
759 float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000760
giuros0105fb4482019-03-26 17:44:40 +0000761 // Base-case prime transform
762 fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
763
764 // Store outputs
765 wrapper::vstore(X + M * k, a);
766 wrapper::vstore(X + M * (k + 2 * Nx), b);
767 wrapper::vstore(X + M * (k + 4 * Nx), c);
768 wrapper::vstore(X + M * (k + 6 * Nx), d);
769 wrapper::vstore(X + M * (k + 8 * Nx), e);
770 wrapper::vstore(X + M * (k + 10 * Nx), f);
771 wrapper::vstore(X + M * (k + 12 * Nx), g);
772 }
773
774 w = c_mul_neon(w, w_m);
775 }
776}
777
778template <bool first_stage>
779void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
780{
781 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000782 for(unsigned int j = 0; j < Nx; j++)
783 {
784 const float32x2_t w2 = c_mul_neon(w, w);
785 const float32x2_t w3 = c_mul_neon(w2, w);
786 const float32x2_t w4 = c_mul_neon(w3, w);
787 const float32x2_t w5 = c_mul_neon(w4, w);
788 const float32x2_t w6 = c_mul_neon(w5, w);
789 const float32x2_t w7 = c_mul_neon(w6, w);
790
giuros0105fb4482019-03-26 17:44:40 +0000791 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000792 {
793 // Load inputs
794 float32x2_t a = { 0, 0 };
795 float32x2_t b = { 0, 0 };
796 float32x2_t c = { 0, 0 };
797 float32x2_t d = { 0, 0 };
798 float32x2_t e = { 0, 0 };
799 float32x2_t f = { 0, 0 };
800 float32x2_t g = { 0, 0 };
801 float32x2_t h = { 0, 0 };
802
803 // Base-case prime transform
804 if(first_stage)
805 {
806 const auto ab = wrapper::vloadq(x + k);
807 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
808 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
809 const auto gh = wrapper::vloadq(x + k + 12 * Nx);
810
811 a = wrapper::vgetlow(ab);
812 b = wrapper::vgethigh(ab);
813 c = wrapper::vgetlow(cd);
814 d = wrapper::vgethigh(cd);
815 e = wrapper::vgetlow(ef);
816 f = wrapper::vgethigh(ef);
817 g = wrapper::vgetlow(gh);
818 h = wrapper::vgethigh(gh);
819 }
820 else
821 {
822 a = wrapper::vload(x + k);
823 b = wrapper::vload(x + k + 2 * Nx);
824 c = wrapper::vload(x + k + 4 * Nx);
825 d = wrapper::vload(x + k + 6 * Nx);
826 e = wrapper::vload(x + k + 8 * Nx);
827 f = wrapper::vload(x + k + 10 * Nx);
828 g = wrapper::vload(x + k + 12 * Nx);
829 h = wrapper::vload(x + k + 14 * Nx);
830 }
831
832 // Apply twiddle factors
833 fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
834
835 // Store outputs
836 if(first_stage)
837 {
838 wrapper::vstore(X + k, wrapper::vcombine(a, b));
839 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
840 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
841 wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h));
842 }
843 else
844 {
845 wrapper::vstore(X + k, a);
846 wrapper::vstore(X + k + 2 * Nx, b);
847 wrapper::vstore(X + k + 4 * Nx, c);
848 wrapper::vstore(X + k + 6 * Nx, d);
849 wrapper::vstore(X + k + 8 * Nx, e);
850 wrapper::vstore(X + k + 10 * Nx, f);
851 wrapper::vstore(X + k + 12 * Nx, g);
852 wrapper::vstore(X + k + 14 * Nx, h);
853 }
854 }
855
856 w = c_mul_neon(w, w_m);
857 }
858}
859
giuros0105fb4482019-03-26 17:44:40 +0000860void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
861{
862 float32x2_t w{ 1.0f, 0.0f };
863 for(unsigned int j = 0; j < Nx; j++)
864 {
865 const float32x2_t w2 = c_mul_neon(w, w);
866 const float32x2_t w3 = c_mul_neon(w2, w);
867 const float32x2_t w4 = c_mul_neon(w3, w);
868 const float32x2_t w5 = c_mul_neon(w4, w);
869 const float32x2_t w6 = c_mul_neon(w5, w);
870 const float32x2_t w7 = c_mul_neon(w6, w);
871
872 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
873 {
874 // Load inputs
875 float32x2_t a = wrapper::vload(x + M * k);
876 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
877 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
878 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
879 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
880 float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
881 float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
882 float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
883
884 // Base-case prime transform
885 fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
886
887 // Store outputs
888 wrapper::vstore(X + M * k, a);
889 wrapper::vstore(X + M * (k + 2 * Nx), b);
890 wrapper::vstore(X + M * (k + 4 * Nx), c);
891 wrapper::vstore(X + M * (k + 6 * Nx), d);
892 wrapper::vstore(X + M * (k + 8 * Nx), e);
893 wrapper::vstore(X + M * (k + 10 * Nx), f);
894 wrapper::vstore(X + M * (k + 12 * Nx), g);
895 wrapper::vstore(X + M * (k + 14 * Nx), h);
896 }
897
898 w = c_mul_neon(w, w_m);
899 }
900}
901
giuros0114c4e0f2019-03-26 17:44:40 +0000902Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
903{
904 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
giuros0105fb4482019-03-26 17:44:40 +0000905 ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
giuros0114c4e0f2019-03-26 17:44:40 +0000906 ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
giuros0105fb4482019-03-26 17:44:40 +0000907 ARM_COMPUTE_UNUSED(config);
giuros0114c4e0f2019-03-26 17:44:40 +0000908
909 // Checks performed when output is configured
910 if((output != nullptr) && (output->total_size() != 0))
911 {
912 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
913 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
914 }
915
916 return Status{};
917}
918
919std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
920{
giuros0105fb4482019-03-26 17:44:40 +0000921 ARM_COMPUTE_UNUSED(config);
922
giuros0114c4e0f2019-03-26 17:44:40 +0000923 if(output != nullptr)
924 {
925 auto_init_if_empty(*output, *input);
926 }
927
giuros0105fb4482019-03-26 17:44:40 +0000928 Window win = calculate_max_window(*input, Steps());
giuros0114c4e0f2019-03-26 17:44:40 +0000929 if(output != nullptr)
930 {
931 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
932 }
933
934 return std::make_pair(Status{}, win);
935}
936} // namespace
937
938NEFFTRadixStageKernel::NEFFTRadixStageKernel()
giuros0105fb4482019-03-26 17:44:40 +0000939 : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
giuros0114c4e0f2019-03-26 17:44:40 +0000940{
941}
942
giuros0105fb4482019-03-26 17:44:40 +0000943void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
giuros0114c4e0f2019-03-26 17:44:40 +0000944{
giuros0105fb4482019-03-26 17:44:40 +0000945 // FFT table axis 0: [radix, first_stage]
946 static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
947
948 if(fft_table_axis0.empty())
giuros0114c4e0f2019-03-26 17:44:40 +0000949 {
giuros0105fb4482019-03-26 17:44:40 +0000950 fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
951 fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
952 fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
953 fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
954 fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
955 fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
956
957 fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
958 fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
959 fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
960 fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
961 fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
962 fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
giuros0114c4e0f2019-03-26 17:44:40 +0000963 }
giuros0105fb4482019-03-26 17:44:40 +0000964
965 _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
966}
967
968void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
969{
970 // FFT table axis 1: [radix, first_stage]
971 static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
972
973 if(fft_table_axis1.empty())
974 {
975 fft_table_axis1[2] = &fft_radix_2_axes_1;
976 fft_table_axis1[3] = &fft_radix_3_axes_1;
977 fft_table_axis1[4] = &fft_radix_4_axes_1;
978 fft_table_axis1[5] = &fft_radix_5_axes_1;
979 fft_table_axis1[7] = &fft_radix_7_axes_1;
980 fft_table_axis1[8] = &fft_radix_8_axes_1;
981 }
982
983 _func_1 = fft_table_axis1[config.radix];
giuros0114c4e0f2019-03-26 17:44:40 +0000984}
985
986void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
987{
988 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
989
990 // Output auto inizialitation if not yet initialized
991 if(output != nullptr)
992 {
993 auto_init_if_empty(*output->info(), *input->info()->clone());
994 }
995
996 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
997
998 _input = input;
999 _output = output;
1000 _run_in_place = (output == nullptr) || (output == input);
1001 _Nx = config.Nx;
giuros0105fb4482019-03-26 17:44:40 +00001002 _axis = config.axis;
1003 _radix = config.radix;
giuros0114c4e0f2019-03-26 17:44:40 +00001004
giuros0105fb4482019-03-26 17:44:40 +00001005 switch(config.axis)
giuros0114c4e0f2019-03-26 17:44:40 +00001006 {
giuros0105fb4482019-03-26 17:44:40 +00001007 case 0:
1008 set_radix_stage_axis0(config);
1009 break;
1010 case 1:
1011 set_radix_stage_axis1(config);
1012 break;
1013 default:
1014 ARM_COMPUTE_ERROR("Axis not supported");
1015 break;
giuros0114c4e0f2019-03-26 17:44:40 +00001016 }
1017
1018 // Configure kernel window
1019 auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config);
1020 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1021 INEKernel::configure(win_config.second);
1022}
1023
1024Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
1025{
1026 const bool run_in_place = (output == nullptr) || (output == input);
1027 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
1028 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1029 (run_in_place) ? nullptr : output->clone().get(),
1030 config)
1031 .first);
1032
1033 return Status{};
1034}
1035
1036std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
1037{
1038 return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
1039}
1040
1041void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
1042{
giuros0114c4e0f2019-03-26 17:44:40 +00001043 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1044 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
giuros0105fb4482019-03-26 17:44:40 +00001045 ARM_COMPUTE_UNUSED(info);
giuros0114c4e0f2019-03-26 17:44:40 +00001046
1047 Window input_window = window;
giuros0105fb4482019-03-26 17:44:40 +00001048 input_window.set(_axis, 0);
giuros0114c4e0f2019-03-26 17:44:40 +00001049
1050 Iterator in(_input, input_window);
1051 Iterator out(_run_in_place ? _input : _output, input_window);
1052
giuros0105fb4482019-03-26 17:44:40 +00001053 // Precompute FFT constants
1054 const unsigned int NxRadix = _radix * _Nx;
1055 const float alpha = 2.0f * kPi / float(NxRadix);
1056 const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
1057
1058 if(_axis == 0)
giuros0114c4e0f2019-03-26 17:44:40 +00001059 {
giuros0105fb4482019-03-26 17:44:40 +00001060 const unsigned int N = _input->info()->dimension(0);
1061 execute_window_loop(input_window, [&](const Coordinates &)
1062 {
1063 _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
1064 },
1065 in, out);
1066 }
1067 else
1068 {
1069 const unsigned int N = _input->info()->dimension(0);
1070 const unsigned int M = _input->info()->dimension(1);
1071 execute_window_loop(input_window, [&](const Coordinates &)
1072 {
1073 _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
1074 },
1075 in, out);
1076 }
giuros0114c4e0f2019-03-26 17:44:40 +00001077
1078 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1079 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1080}
1081} // namespace arm_compute