blob: 148bbe915ae149975cbe1f54eee19debfc9b109c [file] [log] [blame]
giuros0114c4e0f2019-03-26 17:44:40 +00001/*
2 * Copyright (c) 2019 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24#include "arm_compute/core/NEON/kernels/NEFFTRadixStageKernel.h"
25
26#include "arm_compute/core/ITensor.h"
giuros0114c4e0f2019-03-26 17:44:40 +000027#include "arm_compute/core/TensorInfo.h"
28#include "arm_compute/core/Types.h"
29#include "arm_compute/core/Utils.h"
30#include "arm_compute/core/Window.h"
31
32#include <arm_neon.h>
33#include <cmath>
34#include <complex>
giuros0105fb4482019-03-26 17:44:40 +000035#include <map>
36
37#include "arm_compute/core/NEON/wrapper/traits.h"
38#include "arm_compute/core/NEON/wrapper/wrapper.h"
giuros0114c4e0f2019-03-26 17:44:40 +000039
40namespace arm_compute
41{
42namespace
43{
// Single-precision value of pi, converted from the M_PI macro of <cmath>
constexpr float kPi = static_cast<float>(M_PI);

// sin(2*pi/3) = sqrt(3)/2: imaginary magnitude of the radix-3 roots of unity (fft_3 kernel)
constexpr float kSqrt3Div2 = 0.866025403784438;

// Real/imaginary magnitudes of the radix-5 roots of unity (fft_5 kernel)
constexpr float kW5_0 = 0.30901699437494f; //  cos(2*pi/5)
constexpr float kW5_1 = 0.95105651629515f; //  sin(2*pi/5)
constexpr float kW5_2 = 0.80901699437494f; // -cos(4*pi/5)
constexpr float kW5_3 = 0.58778525229247f; //  sin(4*pi/5)

// Real/imaginary magnitudes of the radix-7 roots of unity (fft_7 kernel)
constexpr float kW7_0 = 0.62348980185873f; //  cos(2*pi/7)
constexpr float kW7_1 = 0.78183148246802f; //  sin(2*pi/7)
constexpr float kW7_2 = 0.22252093395631f; // -cos(4*pi/7)
constexpr float kW7_3 = 0.97492791218182f; //  sin(4*pi/7)
constexpr float kW7_4 = 0.90096886790241f; // -cos(6*pi/7)
constexpr float kW7_5 = 0.43388373911755f; //  sin(6*pi/7)

// cos(pi/4) = sin(pi/4) = sqrt(2)/2: magnitude used by the odd radix-8 roots of unity (fft_8 kernel)
constexpr float kSqrt2Div2 = 0.707106781186548;
giuros0114c4e0f2019-03-26 17:44:40 +000066
67float32x2_t c_mul_neon(float32x2_t a, float32x2_t b)
68{
giuros0105fb4482019-03-26 17:44:40 +000069 using ExactTagType = typename wrapper::traits::neon_vector<float, 2>::tag_type;
giuros0114c4e0f2019-03-26 17:44:40 +000070
giuros0105fb4482019-03-26 17:44:40 +000071 const float32x2_t mask = { -1.0, 1.0 };
72 const float32x2_t tmp0 = wrapper::vdup_n(wrapper::vgetlane(a, 0), ExactTagType{});
73 const float32x2_t tmp1 = wrapper::vdup_n(wrapper::vgetlane(a, 1), ExactTagType{});
giuros0114c4e0f2019-03-26 17:44:40 +000074
giuros0105fb4482019-03-26 17:44:40 +000075 float32x2_t res = wrapper::vmul(tmp0, b);
giuros0114c4e0f2019-03-26 17:44:40 +000076
giuros0105fb4482019-03-26 17:44:40 +000077 b = wrapper::vrev64(b);
78 b = wrapper::vmul(b, mask);
79 res = wrapper::vmla(res, tmp1, b);
80
81 return res;
giuros0114c4e0f2019-03-26 17:44:40 +000082}
83
84float32x2_t c_mul_neon_img(float32x2_t a, float img_constant)
85{
86 const float a_r = wrapper::vgetlane(a, 0);
87 const float a_i = wrapper::vgetlane(a, 1);
88
89 const auto out = wrapper::vmul(float32x2_t{ -a_i, a_r }, float32x2_t{ img_constant, img_constant });
90 return out;
91}
92
93float32x2_t reduce_sum_5(float32x2_t a, float32x2_t b, float32x2_t c, float32x2_t d, float32x2_t e)
94{
95 const auto t0 = wrapper::vadd(a, b);
96 const auto t1 = wrapper::vadd(c, d);
97 const auto t2 = wrapper::vadd(t0, t1);
98 return wrapper::vadd(t2, e);
99}
100
101float32x2_t reduce_sum_7(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7)
102{
103 const auto t0 = wrapper::vadd(x1, x2);
104 const auto t1 = wrapper::vadd(x3, x4);
105 const auto t2 = wrapper::vadd(x5, x6);
106 const auto t00 = wrapper::vadd(t0, t1);
107 const auto t01 = wrapper::vadd(t2, x7);
108
109 return wrapper::vadd(t00, t01);
110}
111
112float32x2_t reduce_sum_8(float32x2_t x1, float32x2_t x2, float32x2_t x3, float32x2_t x4, float32x2_t x5, float32x2_t x6, float32x2_t x7, float32x2_t x8)
113{
114 const auto t0 = wrapper::vadd(x1, x2);
115 const auto t1 = wrapper::vadd(x3, x4);
116 const auto t2 = wrapper::vadd(x5, x6);
117 const auto t3 = wrapper::vadd(x7, x8);
118 const auto t00 = wrapper::vadd(t0, t1);
119 const auto t01 = wrapper::vadd(t2, t3);
120
121 return wrapper::vadd(t00, t01);
122}
123
124void fft_2(float32x2_t &x, float32x2_t &y, float32x2_t &w)
125{
126 float32x2_t a = x;
127 float32x2_t b = c_mul_neon(w, y);
128
129 x = wrapper::vadd(a, b);
130 y = wrapper::vsub(a, b);
131}
132
giuros0114c4e0f2019-03-26 17:44:40 +0000133void fft_3(float32x2_t &x, float32x2_t &y, float32x2_t &z, const float32x2_t &w, const float32x2_t &w2)
134{
135 float32x2_t a = x;
136 float32x2_t b = c_mul_neon(w, y);
137 float32x2_t c = c_mul_neon(w2, z);
138
139 x = wrapper::vadd(a, b);
140 x = wrapper::vadd(x, c);
141
142 const auto v1 = wrapper::vmul(float32x2_t{ 0.5f, 0.5 }, wrapper::vadd(b, c));
giuros0105fb4482019-03-26 17:44:40 +0000143 const auto v2 = c_mul_neon(float32x2_t{ 0.f, -kSqrt3Div2 }, wrapper::vsub(b, c));
giuros0114c4e0f2019-03-26 17:44:40 +0000144
145 y = z = wrapper::vsub(a, v1);
146 y = wrapper::vadd(y, v2);
147 z = wrapper::vsub(z, v2);
148}
149
150void fft_4(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3)
151{
152 float32x2_t a = x1;
153 float32x2_t b = c_mul_neon(w, x2);
154 float32x2_t c = c_mul_neon(w2, x3);
155 float32x2_t d = c_mul_neon(w3, x4);
156
157 const auto x11 = wrapper::vadd(a, b);
158 const auto x12 = wrapper::vadd(c, d);
159 x1 = wrapper::vadd(x11, x12);
160
161 const auto x21 = wrapper::vadd(a, c_mul_neon_img(b, -1));
162 const auto x22 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, 1.f));
163 x2 = wrapper::vadd(x21, x22);
164
165 const auto x31 = wrapper::vadd(a, wrapper::vneg(b));
166 const auto x32 = wrapper::vadd(c, wrapper::vneg(d));
167 x3 = wrapper::vadd(x31, x32);
168
169 const auto x41 = wrapper::vadd(a, c_mul_neon_img(b, 1));
170 const auto x42 = wrapper::vadd(wrapper::vneg(c), c_mul_neon_img(d, -1));
171 x4 = wrapper::vadd(x41, x42);
172}
173
giuros0114c4e0f2019-03-26 17:44:40 +0000174void fft_5(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3, const float32x2_t &w4)
175{
176 const auto a = x1;
177 const auto b = c_mul_neon(w, x2);
178 const auto c = c_mul_neon(w2, x3);
179 const auto d = c_mul_neon(w3, x4);
180 const auto e = c_mul_neon(w4, x5);
181
giuros0105fb4482019-03-26 17:44:40 +0000182 const auto b0 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, b);
183 const auto b1 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, b);
184 const auto b2 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, b);
185 const auto b3 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000186
giuros0105fb4482019-03-26 17:44:40 +0000187 const auto c0 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, c);
188 const auto c1 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, c);
189 const auto c2 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, c);
190 const auto c3 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, c);
giuros0114c4e0f2019-03-26 17:44:40 +0000191
giuros0105fb4482019-03-26 17:44:40 +0000192 const auto d0 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, d);
193 const auto d1 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, d);
194 const auto d2 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, d);
195 const auto d3 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000196
giuros0105fb4482019-03-26 17:44:40 +0000197 const auto e0 = c_mul_neon(float32x2_t{ kW5_0, kW5_1 }, e);
198 const auto e1 = c_mul_neon(float32x2_t{ -kW5_2, kW5_3 }, e);
199 const auto e2 = c_mul_neon(float32x2_t{ -kW5_2, -kW5_3 }, e);
200 const auto e3 = c_mul_neon(float32x2_t{ kW5_0, -kW5_1 }, e);
giuros0114c4e0f2019-03-26 17:44:40 +0000201
202 x1 = reduce_sum_5(a, b, c, d, e);
203 x2 = reduce_sum_5(a, b0, c0, d0, e0);
204 x3 = reduce_sum_5(a, b1, c1, d1, e1);
205 x4 = reduce_sum_5(a, b2, c2, d2, e2);
206 x5 = reduce_sum_5(a, b3, c3, d3, e3);
207}
208
giuros0114c4e0f2019-03-26 17:44:40 +0000209void fft_7(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, const float32x2_t &w, const float32x2_t &w2, const float32x2_t &w3,
210 const float32x2_t &w4,
211 const float32x2_t &w5, const float32x2_t &w6)
212{
213 const auto a = x1;
214 const auto b = c_mul_neon(w, x2);
215 const auto c = c_mul_neon(w2, x3);
216 const auto d = c_mul_neon(w3, x4);
217 const auto e = c_mul_neon(w4, x5);
218 const auto f = c_mul_neon(w5, x6);
219 const auto g = c_mul_neon(w6, x7);
220
giuros0105fb4482019-03-26 17:44:40 +0000221 const auto b0 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, b);
222 const auto b1 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, b);
223 const auto b2 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, b);
224 const auto b3 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, b);
225 const auto b4 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, b);
226 const auto b5 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000227
giuros0105fb4482019-03-26 17:44:40 +0000228 const auto c0 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, c);
229 const auto c1 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, c);
230 const auto c2 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, c);
231 const auto c3 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, c);
232 const auto c4 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, c);
233 const auto c5 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, c);
giuros0114c4e0f2019-03-26 17:44:40 +0000234
giuros0105fb4482019-03-26 17:44:40 +0000235 const auto d0 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, d);
236 const auto d1 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, d);
237 const auto d2 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, d);
238 const auto d3 = c_mul_neon(float32x2_t{ -kW7_2, +kW7_3 }, d);
239 const auto d4 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, d);
240 const auto d5 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000241
giuros0105fb4482019-03-26 17:44:40 +0000242 const auto e0 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, e);
243 const auto e1 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, e);
244 const auto e2 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, e);
245 const auto e3 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, e);
246 const auto e4 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, e);
247 const auto e5 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, e);
giuros0114c4e0f2019-03-26 17:44:40 +0000248
giuros0105fb4482019-03-26 17:44:40 +0000249 const auto f0 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, f);
250 const auto f1 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, f);
251 const auto f2 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, f);
252 const auto f3 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, f);
253 const auto f4 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, f);
254 const auto f5 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000255
giuros0105fb4482019-03-26 17:44:40 +0000256 const auto g0 = c_mul_neon(float32x2_t{ kW7_0, kW7_1 }, g);
257 const auto g1 = c_mul_neon(float32x2_t{ -kW7_2, kW7_3 }, g);
258 const auto g2 = c_mul_neon(float32x2_t{ -kW7_4, kW7_5 }, g);
259 const auto g3 = c_mul_neon(float32x2_t{ -kW7_4, -kW7_5 }, g);
260 const auto g4 = c_mul_neon(float32x2_t{ -kW7_2, -kW7_3 }, g);
261 const auto g5 = c_mul_neon(float32x2_t{ kW7_0, -kW7_1 }, g);
giuros0114c4e0f2019-03-26 17:44:40 +0000262
263 x1 = reduce_sum_7(a, b, c, d, e, f, g);
264 x2 = reduce_sum_7(a, b0, c0, d0, e0, f0, g0);
265 x3 = reduce_sum_7(a, b1, c1, d1, e1, f1, g1);
266 x4 = reduce_sum_7(a, b2, c2, d2, e2, f2, g2);
267 x5 = reduce_sum_7(a, b3, c3, d3, e3, f3, g3);
268 x6 = reduce_sum_7(a, b4, c4, d4, e4, f4, g4);
269 x7 = reduce_sum_7(a, b5, c5, d5, e5, f5, g5);
270}
271
giuros0114c4e0f2019-03-26 17:44:40 +0000272void fft_8(float32x2_t &x1, float32x2_t &x2, float32x2_t &x3, float32x2_t &x4, float32x2_t &x5, float32x2_t &x6, float32x2_t &x7, float32x2_t &x8, const float32x2_t &w, const float32x2_t &w2,
273 const float32x2_t &w3,
274 const float32x2_t &w4, const float32x2_t &w5, const float32x2_t &w6,
275 const float32x2_t &w7)
276{
277 const auto a = x1;
278 const auto b = c_mul_neon(w, x2);
279 const auto c = c_mul_neon(w2, x3);
280 const auto d = c_mul_neon(w3, x4);
281 const auto e = c_mul_neon(w4, x5);
282 const auto f = c_mul_neon(w5, x6);
283 const auto g = c_mul_neon(w6, x7);
284 const auto h = c_mul_neon(w7, x8);
285
giuros0105fb4482019-03-26 17:44:40 +0000286 const auto b0 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000287 const auto b1 = c_mul_neon(float32x2_t{ 0, -1 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000288 const auto b2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000289 const auto b3 = c_mul_neon(float32x2_t{ -1, 0 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000290 const auto b4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000291 const auto b5 = c_mul_neon(float32x2_t{ 0, 1 }, b);
giuros0105fb4482019-03-26 17:44:40 +0000292 const auto b6 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, b);
giuros0114c4e0f2019-03-26 17:44:40 +0000293
294 const auto c0 = c_mul_neon(float32x2_t{ 0, -1 }, c);
295 const auto c1 = c_mul_neon(float32x2_t{ -1, 0 }, c);
296 const auto c2 = c_mul_neon(float32x2_t{ 0, 1 }, c);
297 const auto c3 = c_mul_neon(float32x2_t{ 1, 0 }, c);
298 const auto c4 = c_mul_neon(float32x2_t{ 0, -1 }, c);
299 const auto c5 = c_mul_neon(float32x2_t{ -1, 0 }, c);
300 const auto c6 = c_mul_neon(float32x2_t{ 0, 1 }, c);
301
giuros0105fb4482019-03-26 17:44:40 +0000302 const auto d0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000303 const auto d1 = c_mul_neon(float32x2_t{ 0, 1 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000304 const auto d2 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000305 const auto d3 = c_mul_neon(float32x2_t{ -1, 0 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000306 const auto d4 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000307 const auto d5 = c_mul_neon(float32x2_t{ 0, -1 }, d);
giuros0105fb4482019-03-26 17:44:40 +0000308 const auto d6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, d);
giuros0114c4e0f2019-03-26 17:44:40 +0000309
310 const auto e0 = c_mul_neon(float32x2_t{ -1, 0 }, e);
311 const auto e1 = c_mul_neon(float32x2_t{ 1, 0 }, e);
312 const auto e2 = c_mul_neon(float32x2_t{ -1, 0 }, e);
313 const auto e3 = c_mul_neon(float32x2_t{ 1, 0 }, e);
314 const auto e4 = c_mul_neon(float32x2_t{ -1, 0 }, e);
315 const auto e5 = c_mul_neon(float32x2_t{ 1, 0 }, e);
316 const auto e6 = c_mul_neon(float32x2_t{ -1, 0 }, e);
317
giuros0105fb4482019-03-26 17:44:40 +0000318 const auto f0 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000319 const auto f1 = c_mul_neon(float32x2_t{ 0, -1 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000320 const auto f2 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000321 const auto f3 = c_mul_neon(float32x2_t{ -1, 0 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000322 const auto f4 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000323 const auto f5 = c_mul_neon(float32x2_t{ 0, 1 }, f);
giuros0105fb4482019-03-26 17:44:40 +0000324 const auto f6 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, f);
giuros0114c4e0f2019-03-26 17:44:40 +0000325
326 const auto g0 = c_mul_neon(float32x2_t{ 0, 1 }, g);
327 const auto g1 = c_mul_neon(float32x2_t{ -1, 0 }, g);
328 const auto g2 = c_mul_neon(float32x2_t{ 0, -1 }, g);
329 const auto g3 = c_mul_neon(float32x2_t{ 1, 0 }, g);
330 const auto g4 = c_mul_neon(float32x2_t{ 0, 1 }, g);
331 const auto g5 = c_mul_neon(float32x2_t{ -1, 0 }, g);
332 const auto g6 = c_mul_neon(float32x2_t{ 0, -1 }, g);
333
giuros0105fb4482019-03-26 17:44:40 +0000334 const auto h0 = c_mul_neon(float32x2_t{ kSqrt2Div2, kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000335 const auto h1 = c_mul_neon(float32x2_t{ 0, 1 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000336 const auto h2 = c_mul_neon(float32x2_t{ -kSqrt2Div2, kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000337 const auto h3 = c_mul_neon(float32x2_t{ -1, 0 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000338 const auto h4 = c_mul_neon(float32x2_t{ -kSqrt2Div2, -kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000339 const auto h5 = c_mul_neon(float32x2_t{ 0, -1 }, h);
giuros0105fb4482019-03-26 17:44:40 +0000340 const auto h6 = c_mul_neon(float32x2_t{ kSqrt2Div2, -kSqrt2Div2 }, h);
giuros0114c4e0f2019-03-26 17:44:40 +0000341
342 x1 = reduce_sum_8(a, b, c, d, e, f, g, h);
343 x2 = reduce_sum_8(a, b0, c0, d0, e0, f0, g0, h0);
344 x3 = reduce_sum_8(a, b1, c1, d1, e1, f1, g1, h1);
345 x4 = reduce_sum_8(a, b2, c2, d2, e2, f2, g2, h2);
346 x5 = reduce_sum_8(a, b3, c3, d3, e3, f3, g3, h3);
347 x6 = reduce_sum_8(a, b4, c4, d4, e4, f4, g4, h4);
348 x7 = reduce_sum_8(a, b5, c5, d5, e5, f5, g5, h5);
349 x8 = reduce_sum_8(a, b6, c6, d6, e6, f6, g6, h6);
350}
351
352template <bool first_stage>
giuros0105fb4482019-03-26 17:44:40 +0000353void fft_radix_2_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000354{
giuros0105fb4482019-03-26 17:44:40 +0000355 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000356 for(unsigned int j = 0; j < Nx; j++)
357 {
giuros0105fb4482019-03-26 17:44:40 +0000358 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000359 {
360 auto a = float32x2_t{ 0, 0 };
361 auto b = float32x2_t{ 0, 0 };
362
363 // Load inputs
364 if(first_stage)
365 {
366 const auto ab = wrapper::vloadq(x + k);
367 a = wrapper::vgetlow(ab);
368 b = wrapper::vgethigh(ab);
369 }
370 else
371 {
372 a = wrapper::vload(x + k);
373 b = wrapper::vload(x + k + 2 * Nx);
374 }
375
376 // Base-case prime transform
377 fft_2(a, b, w);
378
379 // Write outputs
380 if(first_stage)
381 {
382 wrapper::vstore(X + k, wrapper::vcombine(a, b));
383 }
384 else
385 {
386 wrapper::vstore(X + k, a);
387 wrapper::vstore(X + k + 2 * Nx, b);
388 }
389 }
390
391 w = c_mul_neon(w, w_m);
392 }
393}
394
giuros0105fb4482019-03-26 17:44:40 +0000395void fft_radix_2_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000396{
giuros0105fb4482019-03-26 17:44:40 +0000397 float32x2_t w{ 1.0f, 0.0f };
398 for(unsigned int j = 0; j < Nx; j++)
399 {
400 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
401 {
402 // Load inputs
403 float32x2_t a = wrapper::vload(x + M * k);
404 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000405
giuros0105fb4482019-03-26 17:44:40 +0000406 // Base-case prime transform
407 fft_2(a, b, w);
408
409 // Write outputs
410 wrapper::vstore(X + M * k, a);
411 wrapper::vstore(X + M * (k + 2 * Nx), b);
412 }
413
414 w = c_mul_neon(w, w_m);
415 }
416}
417
418template <bool first_stage>
419void fft_radix_3_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
420{
421 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000422 for(unsigned int j = 0; j < Nx; j++)
423 {
424 const auto w2 = c_mul_neon(w, w);
425
giuros0105fb4482019-03-26 17:44:40 +0000426 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000427 {
428 // Load inputs
429 float32x2_t a = { 0, 0 };
430 float32x2_t b = { 0, 0 };
431 float32x2_t c = { 0, 0 };
432 if(first_stage)
433 {
434 const auto ab = wrapper::vloadq(x + k);
435 a = wrapper::vgetlow(ab);
436 b = wrapper::vgethigh(ab);
437 }
438 else
439 {
440 a = wrapper::vload(x + k);
441 b = wrapper::vload(x + k + 2 * Nx);
442 }
443 c = wrapper::vload(x + k + 4 * Nx);
444
445 // Base-case prime transform
446 fft_3(a, b, c, w, w2);
447
448 if(first_stage)
449 {
450 wrapper::vstore(X + k, wrapper::vcombine(a, b));
451 }
452 else
453 {
454 wrapper::vstore(X + k, a);
455 wrapper::vstore(X + k + 2 * Nx, b);
456 }
457 wrapper::vstore(X + k + 4 * Nx, c);
458 }
459 w = c_mul_neon(w, w_m);
460 }
461}
462
giuros0105fb4482019-03-26 17:44:40 +0000463void fft_radix_3_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000464{
giuros0105fb4482019-03-26 17:44:40 +0000465 float32x2_t w{ 1.0f, 0.0f };
466 for(unsigned int j = 0; j < Nx; j++)
467 {
468 const auto w2 = c_mul_neon(w, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000469
giuros0105fb4482019-03-26 17:44:40 +0000470 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
471 {
472 // Load inputs
473 float32x2_t a = wrapper::vload(x + M * k);
474 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
475 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000476
giuros0105fb4482019-03-26 17:44:40 +0000477 // Base-case prime transform
478 fft_3(a, b, c, w, w2);
479
480 // Store the output
481 wrapper::vstore(X + M * k, a);
482 wrapper::vstore(X + M * (k + 2 * Nx), b);
483 wrapper::vstore(X + M * (k + 4 * Nx), c);
484 }
485 w = c_mul_neon(w, w_m);
486 }
487}
488
489template <bool first_stage>
490void fft_radix_4_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
491{
492 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000493 for(unsigned int j = 0; j < Nx; j++)
494 {
495 const auto w2 = c_mul_neon(w, w);
496 const auto w3 = c_mul_neon(w2, w);
497
giuros0105fb4482019-03-26 17:44:40 +0000498 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000499 {
500 float32x2_t a = { 0, 0 };
501 float32x2_t b = { 0, 0 };
502 float32x2_t c = { 0, 0 };
503 float32x2_t d = { 0, 0 };
504 if(first_stage)
505 {
506 const auto ab = wrapper::vloadq(x + k);
507 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
508 a = wrapper::vgetlow(ab);
509 b = wrapper::vgethigh(ab);
510 c = wrapper::vgetlow(cd);
511 d = wrapper::vgethigh(cd);
512 }
513 else
514 {
515 // Load inputs
516 a = wrapper::vload(x + k);
517 b = wrapper::vload(x + k + 2 * Nx);
518 c = wrapper::vload(x + k + 4 * Nx);
519 d = wrapper::vload(x + k + 6 * Nx);
520 }
521
522 // Base-case prime transform
523 fft_4(a, b, c, d, w, w2, w3);
524
525 if(first_stage)
526 {
527 wrapper::vstore(X + k, wrapper::vcombine(a, b));
528 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
529 }
530 else
531 {
532 wrapper::vstore(X + k, a);
533 wrapper::vstore(X + k + 2 * Nx, b);
534 wrapper::vstore(X + k + 4 * Nx, c);
535 wrapper::vstore(X + k + 6 * Nx, d);
536 }
537 }
538
539 w = c_mul_neon(w, w_m);
540 }
541}
542
giuros0105fb4482019-03-26 17:44:40 +0000543void fft_radix_4_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000544{
giuros0105fb4482019-03-26 17:44:40 +0000545 float32x2_t w{ 1.0f, 0.0f };
546 for(unsigned int j = 0; j < Nx; j++)
547 {
548 const auto w2 = c_mul_neon(w, w);
549 const auto w3 = c_mul_neon(w2, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000550
giuros0105fb4482019-03-26 17:44:40 +0000551 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
552 {
553 // Load inputs
554 float32x2_t a = wrapper::vload(x + M * k);
555 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
556 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
557 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000558
giuros0105fb4482019-03-26 17:44:40 +0000559 // Base-case prime transform
560 fft_4(a, b, c, d, w, w2, w3);
561
562 wrapper::vstore(X + M * k, a);
563 wrapper::vstore(X + M * (k + 2 * Nx), b);
564 wrapper::vstore(X + M * (k + 4 * Nx), c);
565 wrapper::vstore(X + M * (k + 6 * Nx), d);
566 }
567
568 w = c_mul_neon(w, w_m);
569 }
570}
571
572template <bool first_stage>
573void fft_radix_5_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
574{
575 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000576 for(unsigned int j = 0; j < Nx; j++)
577 {
578 const float32x2_t w2 = c_mul_neon(w, w);
579 const float32x2_t w3 = c_mul_neon(w2, w);
580 const float32x2_t w4 = c_mul_neon(w3, w);
581
giuros0105fb4482019-03-26 17:44:40 +0000582 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000583 {
584 float32x2_t a = { 0, 0 };
585 float32x2_t b = { 0, 0 };
586 float32x2_t c = { 0, 0 };
587 float32x2_t d = { 0, 0 };
588 float32x2_t e = { 0, 0 };
589
590 // Load inputs
591 if(first_stage)
592 {
593 const auto ab = wrapper::vloadq(x + k);
594 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
595
596 a = wrapper::vgetlow(ab);
597 b = wrapper::vgethigh(ab);
598 c = wrapper::vgetlow(cd);
599 d = wrapper::vgethigh(cd);
600 }
601 else
602 {
603 a = wrapper::vload(x + k);
604 b = wrapper::vload(x + k + 2 * Nx);
605 c = wrapper::vload(x + k + 4 * Nx);
606 d = wrapper::vload(x + k + 6 * Nx);
607 }
608 e = wrapper::vload(x + k + 8 * Nx);
609
610 // Base-case prime transform
611 fft_5(a, b, c, d, e, w, w2, w3, w4);
612
613 // Store outputs
614 if(first_stage)
615 {
616 wrapper::vstore(X + k, wrapper::vcombine(a, b));
617 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
618 }
619 else
620 {
621 wrapper::vstore(X + k, a);
622 wrapper::vstore(X + k + 2 * Nx, b);
623 wrapper::vstore(X + k + 4 * Nx, c);
624 wrapper::vstore(X + k + 6 * Nx, d);
625 }
626 wrapper::vstore(X + k + 8 * Nx, e);
627 }
628
629 w = c_mul_neon(w, w_m);
630 }
631}
632
giuros0105fb4482019-03-26 17:44:40 +0000633void fft_radix_5_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000634{
giuros0105fb4482019-03-26 17:44:40 +0000635 float32x2_t w{ 1.0f, 0.0f };
636 for(unsigned int j = 0; j < Nx; j++)
637 {
638 const float32x2_t w2 = c_mul_neon(w, w);
639 const float32x2_t w3 = c_mul_neon(w2, w);
640 const float32x2_t w4 = c_mul_neon(w3, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000641
giuros0105fb4482019-03-26 17:44:40 +0000642 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
643 {
644 // Load inputs
645 float32x2_t a = wrapper::vload(x + M * k);
646 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
647 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
648 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
649 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000650
giuros0105fb4482019-03-26 17:44:40 +0000651 // Base-case prime transform
652 fft_5(a, b, c, d, e, w, w2, w3, w4);
653
654 // Store outputs
655 wrapper::vstore(X + M * k, a);
656 wrapper::vstore(X + M * (k + 2 * Nx), b);
657 wrapper::vstore(X + M * (k + 4 * Nx), c);
658 wrapper::vstore(X + M * (k + 6 * Nx), d);
659 wrapper::vstore(X + M * (k + 8 * Nx), e);
660 }
661
662 w = c_mul_neon(w, w_m);
663 }
664}
665
666template <bool first_stage>
667void fft_radix_7_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
668{
669 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000670 for(unsigned int j = 0; j < Nx; j++)
671 {
672 const float32x2_t w2 = c_mul_neon(w, w);
673 const float32x2_t w3 = c_mul_neon(w2, w);
674 const float32x2_t w4 = c_mul_neon(w3, w);
675 const float32x2_t w5 = c_mul_neon(w4, w);
676 const float32x2_t w6 = c_mul_neon(w5, w);
677
giuros0105fb4482019-03-26 17:44:40 +0000678 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000679 {
680 float32x2_t a = { 0, 0 };
681 float32x2_t b = { 0, 0 };
682 float32x2_t c = { 0, 0 };
683 float32x2_t d = { 0, 0 };
684 float32x2_t e = { 0, 0 };
685 float32x2_t f = { 0, 0 };
686 float32x2_t g = { 0, 0 };
687
688 // Load inputs
689 if(first_stage)
690 {
691 const auto ab = wrapper::vloadq(x + k);
692 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
693 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
694
695 a = wrapper::vgetlow(ab);
696 b = wrapper::vgethigh(ab);
697 c = wrapper::vgetlow(cd);
698 d = wrapper::vgethigh(cd);
699 e = wrapper::vgetlow(ef);
700 f = wrapper::vgethigh(ef);
701 }
702 else
703 {
704 a = wrapper::vload(x + k);
705 b = wrapper::vload(x + k + 2 * Nx);
706 c = wrapper::vload(x + k + 4 * Nx);
707 d = wrapper::vload(x + k + 6 * Nx);
708 e = wrapper::vload(x + k + 8 * Nx);
709 f = wrapper::vload(x + k + 10 * Nx);
710 }
711 g = wrapper::vload(x + k + 12 * Nx);
712
713 // Base-case prime transform
714 fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
715
716 if(first_stage)
717 {
718 wrapper::vstore(X + k, wrapper::vcombine(a, b));
719 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
720 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
721 }
722 else
723 {
724 wrapper::vstore(X + k, a);
725 wrapper::vstore(X + k + 2 * Nx, b);
726 wrapper::vstore(X + k + 4 * Nx, c);
727 wrapper::vstore(X + k + 6 * Nx, d);
728 wrapper::vstore(X + k + 8 * Nx, e);
729 wrapper::vstore(X + k + 10 * Nx, f);
730 }
731 wrapper::vstore(X + k + 12 * Nx, g);
732 }
733
734 w = c_mul_neon(w, w_m);
735 }
736}
737
giuros0105fb4482019-03-26 17:44:40 +0000738void fft_radix_7_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
giuros0114c4e0f2019-03-26 17:44:40 +0000739{
giuros0105fb4482019-03-26 17:44:40 +0000740 float32x2_t w{ 1.0f, 0.0f };
741 for(unsigned int j = 0; j < Nx; j++)
742 {
743 const float32x2_t w2 = c_mul_neon(w, w);
744 const float32x2_t w3 = c_mul_neon(w2, w);
745 const float32x2_t w4 = c_mul_neon(w3, w);
746 const float32x2_t w5 = c_mul_neon(w4, w);
747 const float32x2_t w6 = c_mul_neon(w5, w);
giuros0114c4e0f2019-03-26 17:44:40 +0000748
giuros0105fb4482019-03-26 17:44:40 +0000749 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
750 {
751 // Load inputs
752 float32x2_t a = wrapper::vload(x + M * k);
753 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
754 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
755 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
756 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
757 float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
758 float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
giuros0114c4e0f2019-03-26 17:44:40 +0000759
giuros0105fb4482019-03-26 17:44:40 +0000760 // Base-case prime transform
761 fft_7(a, b, c, d, e, f, g, w, w2, w3, w4, w5, w6);
762
763 // Store outputs
764 wrapper::vstore(X + M * k, a);
765 wrapper::vstore(X + M * (k + 2 * Nx), b);
766 wrapper::vstore(X + M * (k + 4 * Nx), c);
767 wrapper::vstore(X + M * (k + 6 * Nx), d);
768 wrapper::vstore(X + M * (k + 8 * Nx), e);
769 wrapper::vstore(X + M * (k + 10 * Nx), f);
770 wrapper::vstore(X + M * (k + 12 * Nx), g);
771 }
772
773 w = c_mul_neon(w, w_m);
774 }
775}
776
777template <bool first_stage>
778void fft_radix_8_axes_0(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int N)
779{
780 float32x2_t w{ 1.0f, 0.0f };
giuros0114c4e0f2019-03-26 17:44:40 +0000781 for(unsigned int j = 0; j < Nx; j++)
782 {
783 const float32x2_t w2 = c_mul_neon(w, w);
784 const float32x2_t w3 = c_mul_neon(w2, w);
785 const float32x2_t w4 = c_mul_neon(w3, w);
786 const float32x2_t w5 = c_mul_neon(w4, w);
787 const float32x2_t w6 = c_mul_neon(w5, w);
788 const float32x2_t w7 = c_mul_neon(w6, w);
789
giuros0105fb4482019-03-26 17:44:40 +0000790 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
giuros0114c4e0f2019-03-26 17:44:40 +0000791 {
792 // Load inputs
793 float32x2_t a = { 0, 0 };
794 float32x2_t b = { 0, 0 };
795 float32x2_t c = { 0, 0 };
796 float32x2_t d = { 0, 0 };
797 float32x2_t e = { 0, 0 };
798 float32x2_t f = { 0, 0 };
799 float32x2_t g = { 0, 0 };
800 float32x2_t h = { 0, 0 };
801
802 // Base-case prime transform
803 if(first_stage)
804 {
805 const auto ab = wrapper::vloadq(x + k);
806 const auto cd = wrapper::vloadq(x + k + 4 * Nx);
807 const auto ef = wrapper::vloadq(x + k + 8 * Nx);
808 const auto gh = wrapper::vloadq(x + k + 12 * Nx);
809
810 a = wrapper::vgetlow(ab);
811 b = wrapper::vgethigh(ab);
812 c = wrapper::vgetlow(cd);
813 d = wrapper::vgethigh(cd);
814 e = wrapper::vgetlow(ef);
815 f = wrapper::vgethigh(ef);
816 g = wrapper::vgetlow(gh);
817 h = wrapper::vgethigh(gh);
818 }
819 else
820 {
821 a = wrapper::vload(x + k);
822 b = wrapper::vload(x + k + 2 * Nx);
823 c = wrapper::vload(x + k + 4 * Nx);
824 d = wrapper::vload(x + k + 6 * Nx);
825 e = wrapper::vload(x + k + 8 * Nx);
826 f = wrapper::vload(x + k + 10 * Nx);
827 g = wrapper::vload(x + k + 12 * Nx);
828 h = wrapper::vload(x + k + 14 * Nx);
829 }
830
831 // Apply twiddle factors
832 fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
833
834 // Store outputs
835 if(first_stage)
836 {
837 wrapper::vstore(X + k, wrapper::vcombine(a, b));
838 wrapper::vstore(X + k + 4 * Nx, wrapper::vcombine(c, d));
839 wrapper::vstore(X + k + 8 * Nx, wrapper::vcombine(e, f));
840 wrapper::vstore(X + k + 12 * Nx, wrapper::vcombine(g, h));
841 }
842 else
843 {
844 wrapper::vstore(X + k, a);
845 wrapper::vstore(X + k + 2 * Nx, b);
846 wrapper::vstore(X + k + 4 * Nx, c);
847 wrapper::vstore(X + k + 6 * Nx, d);
848 wrapper::vstore(X + k + 8 * Nx, e);
849 wrapper::vstore(X + k + 10 * Nx, f);
850 wrapper::vstore(X + k + 12 * Nx, g);
851 wrapper::vstore(X + k + 14 * Nx, h);
852 }
853 }
854
855 w = c_mul_neon(w, w_m);
856 }
857}
858
giuros0105fb4482019-03-26 17:44:40 +0000859void fft_radix_8_axes_1(float *X, float *x, unsigned int Nx, unsigned int NxRadix, const float32x2_t &w_m, unsigned int M, unsigned int N)
860{
861 float32x2_t w{ 1.0f, 0.0f };
862 for(unsigned int j = 0; j < Nx; j++)
863 {
864 const float32x2_t w2 = c_mul_neon(w, w);
865 const float32x2_t w3 = c_mul_neon(w2, w);
866 const float32x2_t w4 = c_mul_neon(w3, w);
867 const float32x2_t w5 = c_mul_neon(w4, w);
868 const float32x2_t w6 = c_mul_neon(w5, w);
869 const float32x2_t w7 = c_mul_neon(w6, w);
870
871 for(unsigned int k = 2 * j; k < 2 * N; k += 2 * NxRadix)
872 {
873 // Load inputs
874 float32x2_t a = wrapper::vload(x + M * k);
875 float32x2_t b = wrapper::vload(x + M * (k + 2 * Nx));
876 float32x2_t c = wrapper::vload(x + M * (k + 4 * Nx));
877 float32x2_t d = wrapper::vload(x + M * (k + 6 * Nx));
878 float32x2_t e = wrapper::vload(x + M * (k + 8 * Nx));
879 float32x2_t f = wrapper::vload(x + M * (k + 10 * Nx));
880 float32x2_t g = wrapper::vload(x + M * (k + 12 * Nx));
881 float32x2_t h = wrapper::vload(x + M * (k + 14 * Nx));
882
883 // Base-case prime transform
884 fft_8(a, b, c, d, e, f, g, h, w, w2, w3, w4, w5, w6, w7);
885
886 // Store outputs
887 wrapper::vstore(X + M * k, a);
888 wrapper::vstore(X + M * (k + 2 * Nx), b);
889 wrapper::vstore(X + M * (k + 4 * Nx), c);
890 wrapper::vstore(X + M * (k + 6 * Nx), d);
891 wrapper::vstore(X + M * (k + 8 * Nx), e);
892 wrapper::vstore(X + M * (k + 10 * Nx), f);
893 wrapper::vstore(X + M * (k + 12 * Nx), g);
894 wrapper::vstore(X + M * (k + 14 * Nx), h);
895 }
896
897 w = c_mul_neon(w, w_m);
898 }
899}
900
giuros0114c4e0f2019-03-26 17:44:40 +0000901Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
902{
903 ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 2, DataType::F32);
giuros0105fb4482019-03-26 17:44:40 +0000904 ARM_COMPUTE_RETURN_ERROR_ON(config.axis > 1);
giuros0114c4e0f2019-03-26 17:44:40 +0000905 ARM_COMPUTE_RETURN_ERROR_ON(NEFFTRadixStageKernel::supported_radix().count(config.radix) == 0);
giuros0105fb4482019-03-26 17:44:40 +0000906 ARM_COMPUTE_UNUSED(config);
giuros0114c4e0f2019-03-26 17:44:40 +0000907
908 // Checks performed when output is configured
909 if((output != nullptr) && (output->total_size() != 0))
910 {
911 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input, output);
912 ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
913 }
914
915 return Status{};
916}
917
918std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output, const FFTRadixStageKernelInfo &config)
919{
giuros0105fb4482019-03-26 17:44:40 +0000920 ARM_COMPUTE_UNUSED(config);
921
giuros0114c4e0f2019-03-26 17:44:40 +0000922 if(output != nullptr)
923 {
924 auto_init_if_empty(*output, *input);
925 }
926
giuros0105fb4482019-03-26 17:44:40 +0000927 Window win = calculate_max_window(*input, Steps());
giuros0114c4e0f2019-03-26 17:44:40 +0000928 if(output != nullptr)
929 {
930 output->set_valid_region(ValidRegion(Coordinates(), output->tensor_shape()));
931 }
932
933 return std::make_pair(Status{}, win);
934}
935} // namespace
936
937NEFFTRadixStageKernel::NEFFTRadixStageKernel()
giuros0105fb4482019-03-26 17:44:40 +0000938 : _input(nullptr), _output(nullptr), _run_in_place(false), _Nx(0), _axis(0), _radix(0), _func_0(), _func_1()
giuros0114c4e0f2019-03-26 17:44:40 +0000939{
940}
941
giuros0105fb4482019-03-26 17:44:40 +0000942void NEFFTRadixStageKernel::set_radix_stage_axis0(const FFTRadixStageKernelInfo &config)
giuros0114c4e0f2019-03-26 17:44:40 +0000943{
giuros0105fb4482019-03-26 17:44:40 +0000944 // FFT table axis 0: [radix, first_stage]
945 static std::map<unsigned int, std::map<bool, FFTFunctionPointerAxis0>> fft_table_axis0;
946
947 if(fft_table_axis0.empty())
giuros0114c4e0f2019-03-26 17:44:40 +0000948 {
giuros0105fb4482019-03-26 17:44:40 +0000949 fft_table_axis0[2][false] = &fft_radix_2_axes_0<false>;
950 fft_table_axis0[3][false] = &fft_radix_3_axes_0<false>;
951 fft_table_axis0[4][false] = &fft_radix_4_axes_0<false>;
952 fft_table_axis0[5][false] = &fft_radix_5_axes_0<false>;
953 fft_table_axis0[7][false] = &fft_radix_7_axes_0<false>;
954 fft_table_axis0[8][false] = &fft_radix_8_axes_0<false>;
955
956 fft_table_axis0[2][true] = &fft_radix_2_axes_0<true>;
957 fft_table_axis0[3][true] = &fft_radix_3_axes_0<true>;
958 fft_table_axis0[4][true] = &fft_radix_4_axes_0<true>;
959 fft_table_axis0[5][true] = &fft_radix_5_axes_0<true>;
960 fft_table_axis0[7][true] = &fft_radix_7_axes_0<true>;
961 fft_table_axis0[8][true] = &fft_radix_8_axes_0<true>;
giuros0114c4e0f2019-03-26 17:44:40 +0000962 }
giuros0105fb4482019-03-26 17:44:40 +0000963
964 _func_0 = fft_table_axis0[config.radix][config.is_first_stage];
965}
966
967void NEFFTRadixStageKernel::set_radix_stage_axis1(const FFTRadixStageKernelInfo &config)
968{
969 // FFT table axis 1: [radix, first_stage]
970 static std::map<unsigned int, FFTFunctionPointerAxis1> fft_table_axis1;
971
972 if(fft_table_axis1.empty())
973 {
974 fft_table_axis1[2] = &fft_radix_2_axes_1;
975 fft_table_axis1[3] = &fft_radix_3_axes_1;
976 fft_table_axis1[4] = &fft_radix_4_axes_1;
977 fft_table_axis1[5] = &fft_radix_5_axes_1;
978 fft_table_axis1[7] = &fft_radix_7_axes_1;
979 fft_table_axis1[8] = &fft_radix_8_axes_1;
980 }
981
982 _func_1 = fft_table_axis1[config.radix];
giuros0114c4e0f2019-03-26 17:44:40 +0000983}
984
985void NEFFTRadixStageKernel::configure(ITensor *input, ITensor *output, const FFTRadixStageKernelInfo &config)
986{
987 ARM_COMPUTE_ERROR_ON_NULLPTR(input);
988
989 // Output auto inizialitation if not yet initialized
990 if(output != nullptr)
991 {
992 auto_init_if_empty(*output->info(), *input->info()->clone());
993 }
994
995 ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, config));
996
997 _input = input;
998 _output = output;
999 _run_in_place = (output == nullptr) || (output == input);
1000 _Nx = config.Nx;
giuros0105fb4482019-03-26 17:44:40 +00001001 _axis = config.axis;
1002 _radix = config.radix;
giuros0114c4e0f2019-03-26 17:44:40 +00001003
giuros0105fb4482019-03-26 17:44:40 +00001004 switch(config.axis)
giuros0114c4e0f2019-03-26 17:44:40 +00001005 {
giuros0105fb4482019-03-26 17:44:40 +00001006 case 0:
1007 set_radix_stage_axis0(config);
1008 break;
1009 case 1:
1010 set_radix_stage_axis1(config);
1011 break;
1012 default:
1013 ARM_COMPUTE_ERROR("Axis not supported");
1014 break;
giuros0114c4e0f2019-03-26 17:44:40 +00001015 }
1016
1017 // Configure kernel window
1018 auto win_config = validate_and_configure_window(input->info(), (_run_in_place) ? nullptr : output->info(), config);
1019 ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
1020 INEKernel::configure(win_config.second);
1021}
1022
1023Status NEFFTRadixStageKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const FFTRadixStageKernelInfo &config)
1024{
1025 const bool run_in_place = (output == nullptr) || (output == input);
1026 ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input, output, config));
1027 ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input->clone().get(),
1028 (run_in_place) ? nullptr : output->clone().get(),
1029 config)
1030 .first);
1031
1032 return Status{};
1033}
1034
1035std::set<unsigned int> NEFFTRadixStageKernel::supported_radix()
1036{
1037 return std::set<unsigned int> { 2, 3, 4, 5, 7, 8 };
1038}
1039
1040void NEFFTRadixStageKernel::run(const Window &window, const ThreadInfo &info)
1041{
giuros0114c4e0f2019-03-26 17:44:40 +00001042 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1043 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
giuros0105fb4482019-03-26 17:44:40 +00001044 ARM_COMPUTE_UNUSED(info);
giuros0114c4e0f2019-03-26 17:44:40 +00001045
1046 Window input_window = window;
giuros0105fb4482019-03-26 17:44:40 +00001047 input_window.set(_axis, 0);
giuros0114c4e0f2019-03-26 17:44:40 +00001048
1049 Iterator in(_input, input_window);
1050 Iterator out(_run_in_place ? _input : _output, input_window);
1051
giuros0105fb4482019-03-26 17:44:40 +00001052 // Precompute FFT constants
1053 const unsigned int NxRadix = _radix * _Nx;
1054 const float alpha = 2.0f * kPi / float(NxRadix);
1055 const float32x2_t w_m{ cosf(alpha), -sinf(alpha) };
1056
1057 if(_axis == 0)
giuros0114c4e0f2019-03-26 17:44:40 +00001058 {
giuros0105fb4482019-03-26 17:44:40 +00001059 const unsigned int N = _input->info()->dimension(0);
1060 execute_window_loop(input_window, [&](const Coordinates &)
1061 {
1062 _func_0(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N);
1063 },
1064 in, out);
1065 }
1066 else
1067 {
1068 const unsigned int N = _input->info()->dimension(0);
1069 const unsigned int M = _input->info()->dimension(1);
1070 execute_window_loop(input_window, [&](const Coordinates &)
1071 {
1072 _func_1(reinterpret_cast<float *>(out.ptr()), reinterpret_cast<float *>(in.ptr()), _Nx, NxRadix, w_m, N, M);
1073 },
1074 in, out);
1075 }
giuros0114c4e0f2019-03-26 17:44:40 +00001076
1077 ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
1078 ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
1079}
1080} // namespace arm_compute