blob: 26b4b72c0275a5fcd3e7b1be073431dc5ce01c4b [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "PlatformMath.hpp"
Kshitij Sisodia76a15802021-12-24 11:05:11 +000018#include <algorithm>
19#include <numeric>
alexander3c798932021-03-26 21:42:19 +000020
21#if 0 == ARM_DSP_AVAILABLE
22 #include <cmath>
23 #include <numeric>
24#endif /* 0 == ARM_DSP_AVAILABLE */
25
26namespace arm {
27namespace app {
28namespace math {
29
30 float MathUtils::CosineF32(float radians)
31 {
32#if ARM_DSP_AVAILABLE
33 return arm_cos_f32(radians);
34#else /* ARM_DSP_AVAILABLE */
35 return cos(radians);
36#endif /* ARM_DSP_AVAILABLE */
37 }
38
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +010039 float MathUtils::SineF32(float radians)
40 {
41#if ARM_DSP_AVAILABLE
42 return arm_sin_f32(radians);
43#else /* ARM_DSP_AVAILABLE */
44 return sin(radians);
45#endif /* ARM_DSP_AVAILABLE */
46 }
47
alexander3c798932021-03-26 21:42:19 +000048 float MathUtils::SqrtF32(float input)
49 {
50#if ARM_DSP_AVAILABLE
51 float output = 0.f;
52 arm_sqrt_f32(input, &output);
53 return output;
54#else /* ARM_DSP_AVAILABLE */
55 return sqrtf(input);
56#endif /* ARM_DSP_AVAILABLE */
57 }
58
59 float MathUtils::MeanF32(float* ptrSrc, const uint32_t srcLen)
60 {
61 if (!srcLen) {
62 return 0.f;
63 }
64
65#if ARM_DSP_AVAILABLE
66 float result = 0.f;
67 arm_mean_f32(ptrSrc, srcLen, &result);
68 return result;
69#else /* ARM_DSP_AVAILABLE */
70 float acc = std::accumulate(ptrSrc, ptrSrc + srcLen, 0.0);
71 return acc/srcLen;
72#endif /* ARM_DSP_AVAILABLE */
73 }
74
75 float MathUtils::StdDevF32(float* ptrSrc, const uint32_t srcLen,
76 const float mean)
77 {
78 if (!srcLen) {
79 return 0.f;
80 }
81#if ARM_DSP_AVAILABLE
82 /**
83 * Note Standard deviation calculation can be off
84 * by > 0.01 but less than < 0.1, according to
85 * preliminary findings.
86 **/
87 UNUSED(mean);
88 float stdDev = 0;
89 arm_std_f32(ptrSrc, srcLen, &stdDev);
90 return stdDev;
91#else /* ARM_DSP_AVAILABLE */
92 auto VarianceFunction = [=](float acc, const float value) {
93 return acc + (((value - mean) * (value - mean))/ srcLen);
94 };
95
96 float acc = std::accumulate(ptrSrc, ptrSrc + srcLen, 0.0,
97 VarianceFunction);
98
99 return sqrtf(acc);
100#endif /* ARM_DSP_AVAILABLE */
101 }
102
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100103 void MathUtils::FftInitF32(const uint16_t fftLen,
104 FftInstance& fftInstance,
105 const FftType type)
alexander3c798932021-03-26 21:42:19 +0000106 {
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100107 fftInstance.m_fftLen = fftLen;
108 fftInstance.m_initialised = false;
109 fftInstance.m_optimisedOptionAvailable = false;
110 fftInstance.m_type = type;
alexander3c798932021-03-26 21:42:19 +0000111
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100112#if ARM_DSP_AVAILABLE
113 arm_status status = ARM_MATH_ARGUMENT_ERROR;
114 switch (fftInstance.m_type) {
115 case FftType::real:
116 status = arm_rfft_fast_init_f32(&fftInstance.m_instanceReal, fftLen);
117 break;
118
119 case FftType::complex:
120 status = arm_cfft_init_f32(&fftInstance.m_instanceComplex, fftLen);
121 break;
122
123 default:
124 printf_err("Invalid FFT type\n");
125 return;
alexander3c798932021-03-26 21:42:19 +0000126 }
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100127
128 if (ARM_MATH_SUCCESS != status) {
129 printf_err("Failed to initialise FFT for len %d\n", fftLen);
130 } else {
131 fftInstance.m_optimisedOptionAvailable = true;
132 }
alexander3c798932021-03-26 21:42:19 +0000133#endif /* ARM_DSP_AVAILABLE */
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100134
135 if (!fftInstance.m_optimisedOptionAvailable) {
136 debug("Non optimised FFT will be used\n.");
137 }
138
139 fftInstance.m_initialised = true;
140 }
141
142 static void FftRealF32(std::vector<float>& input,
143 std::vector<float>& fftOutput)
144 {
145 const size_t inputLength = input.size();
146 const size_t halfLength = input.size() / 2;
147
148 fftOutput[0] = 0;
149 fftOutput[1] = 0;
150 for (size_t t = 0; t < inputLength; t++) {
151 fftOutput[0] += input[t];
152 fftOutput[1] += input[t] *
153 MathUtils::CosineF32(2 * M_PI * halfLength * t / inputLength);
154 }
155
156 for (size_t k = 1, j = 2; k < halfLength; ++k, j += 2) {
157 float sumReal = 0;
158 float sumImag = 0;
159
160 const float theta = static_cast<float>(2 * M_PI * k / inputLength);
161
162 for (size_t t = 0; t < inputLength; t++) {
163 const auto angle = static_cast<float>(t * theta);
164 sumReal += input[t] * MathUtils::CosineF32(angle);
165 sumImag += -input[t]* MathUtils::SineF32(angle);
166 }
167
168 /* Arrange output to [real0, realN/2, real1, im1, real2, im2, ...] */
169 fftOutput[j] = sumReal;
170 fftOutput[j + 1] = sumImag;
171 }
172 }
173
174 static void FftComplexF32(std::vector<float>& input,
175 std::vector<float>& fftOutput)
176 {
177 const size_t fftLen = input.size() / 2;
178 for (size_t k = 0; k < fftLen; k++) {
179 float sumReal = 0;
180 float sumImag = 0;
181 const auto theta = static_cast<float>(2 * M_PI * k / fftLen);
182 for (size_t t = 0; t < fftLen; t++) {
183 const auto angle = theta * t;
184 const auto cosine = MathUtils::CosineF32(angle);
185 const auto sine = MathUtils::SineF32(angle);
186 sumReal += input[t*2] * cosine + input[t*2 + 1] * sine;
187 sumImag += -input[t*2] * sine + input[t*2 + 1] * cosine;
188 }
189 fftOutput[k*2] = sumReal;
190 fftOutput[k*2 + 1] = sumImag;
191 }
alexander3c798932021-03-26 21:42:19 +0000192 }
193
194 void MathUtils::FftF32(std::vector<float>& input,
195 std::vector<float>& fftOutput,
196 arm::app::math::FftInstance& fftInstance)
197 {
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100198 if (!fftInstance.m_initialised) {
199 printf_err("FFT uninitialised\n");
200 return;
201 } else if (input.size() < fftInstance.m_fftLen) {
202 printf_err("FFT len: %" PRIu16 "; input len: %zu\n",
203 fftInstance.m_fftLen, input.size());
204 return;
205 } else if (fftOutput.size() < input.size()) {
206 printf_err("Output vector len insufficient to hold FFTs\n");
207 return;
alexander3c798932021-03-26 21:42:19 +0000208 }
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100209
210 switch (fftInstance.m_type) {
211 case FftType::real:
212
213#if ARM_DSP_AVAILABLE
214 if (fftInstance.m_optimisedOptionAvailable) {
215 arm_rfft_fast_f32(&fftInstance.m_instanceReal, input.data(), fftOutput.data(), 0);
216 return;
217 }
alexander3c798932021-03-26 21:42:19 +0000218#endif /* ARM_DSP_AVAILABLE */
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100219 FftRealF32(input, fftOutput);
220 return;
221
222 case FftType::complex:
223 if (input.size() < fftInstance.m_fftLen * 2) {
224 printf_err("Complex FFT instance should have input size >= (FFT len x 2)");
225 return;
226 }
227#if ARM_DSP_AVAILABLE
228 if (fftInstance.m_optimisedOptionAvailable) {
229 fftOutput = input; /* Complex function works in-place */
Richard Burton033c9152021-12-07 14:04:44 +0000230 arm_cfft_f32(&fftInstance.m_instanceComplex, fftOutput.data(), 0, 1);
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100231 return;
232 }
233#endif /* ARM_DSP_AVAILABLE */
234 FftComplexF32(input, fftOutput);
235 return;
236
237 default:
238 printf_err("Invalid FFT type\n");
239 return;
240 }
alexander3c798932021-03-26 21:42:19 +0000241 }
242
243 void MathUtils::VecLogarithmF32(std::vector <float>& input,
244 std::vector <float>& output)
245 {
246#if ARM_DSP_AVAILABLE
247 arm_vlog_f32(input.data(), output.data(),
248 output.size());
249#else /* ARM_DSP_AVAILABLE */
250 for (auto in = input.begin(), out = output.begin();
alexanderc350cdc2021-04-29 20:36:09 +0100251 in != input.end() && out != output.end(); ++in, ++out) {
alexander3c798932021-03-26 21:42:19 +0000252 *out = logf(*in);
253 }
254#endif /* ARM_DSP_AVAILABLE */
255 }
256
257 float MathUtils::DotProductF32(float* srcPtrA, float* srcPtrB,
258 const uint32_t srcLen)
259 {
260 float output = 0.f;
261
262#if ARM_DSP_AVAILABLE
263 arm_dot_prod_f32(srcPtrA, srcPtrB, srcLen, &output);
264#else /* ARM_DSP_AVAILABLE */
265 for (uint32_t i = 0; i < srcLen; ++i) {
266 output += *srcPtrA++ * *srcPtrB++;
267 }
268#endif /* ARM_DSP_AVAILABLE */
269
270 return output;
271 }
272
273 bool MathUtils::ComplexMagnitudeSquaredF32(float* ptrSrc,
274 const uint32_t srcLen,
275 float* ptrDst,
276 const uint32_t dstLen)
277 {
278 if (dstLen < srcLen/2) {
279 printf_err("dstLen must be greater than srcLen/2");
280 return false;
281 }
282
283#if ARM_DSP_AVAILABLE
284 arm_cmplx_mag_squared_f32(ptrSrc, ptrDst, srcLen/2);
285#else /* ARM_DSP_AVAILABLE */
Éanna Ó Catháin036f8c72021-09-01 15:44:56 +0100286 for (uint32_t j = 0; j < srcLen/2; ++j) {
alexander3c798932021-03-26 21:42:19 +0000287 const float real = *ptrSrc++;
288 const float im = *ptrSrc++;
289 *ptrDst++ = real*real + im*im;
290 }
291#endif /* ARM_DSP_AVAILABLE */
292 return true;
293 }
294
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000295 void MathUtils::SoftmaxF32(std::vector<float>& vec)
296 {
297 /* Fix for numerical stability and apply exp. */
298 auto start = vec.begin();
299 auto end = vec.end();
300
301 float maxValue = *std::max_element(start, end);
302 for (auto it = start; it != end; ++it) {
303 *it = std::exp((*it) - maxValue);
304 }
305
306 float sumExp = std::accumulate(start, end, 0.0f);
307
308 for (auto it = start; it != end; ++it) {
309 *it = (*it)/sumExp;
310 }
311 }
312
alexander3c798932021-03-26 21:42:19 +0000313} /* namespace math */
314} /* namespace app */
Kshitij Sisodia14ab8d42021-10-22 17:35:01 +0100315} /* namespace arm */