blob: 8b73e4225fe846dbda9bc5c0c3d2ef246c1f9a89 [file] [log] [blame]
liabar01dee53bc2021-10-29 15:59:04 +01001/*
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +01002 * SPDX-FileCopyrightText: Copyright 2021 - 2023 Arm Limited and/or its affiliates
3 * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
liabar01dee53bc2021-10-29 15:59:04 +01004 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "PlatformMath.hpp"
18#include <catch.hpp>
19#include <limits>
20#include <numeric>
21
22TEST_CASE("Test CosineF32")
23{
24 /*Test Constants: */
25 std::vector<double> inputA{
26 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
27 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
28 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
29 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
30 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
31 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
32 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
33 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
34 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
35 0.99, 1.0
36 };
37 std::vector<double> expectedResult{
38 1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308,
39 0.9921147013144779, 0.9876883405951378, 0.9822872507286887,
40 0.9759167619387474, 0.9685831611286311, 0.9602936856769431,
41 0.9510565162951535, 0.9408807689542255, 0.9297764858882515,
42 0.9177546256839811, 0.9048270524660195, 0.8910065241883679,
43 0.8763066800438636, 0.8607420270039436, 0.8443279255020151,
44 0.8270805742745618, 0.8090169943749475, 0.7901550123756904,
45 0.7705132427757893, 0.7501110696304596, 0.7289686274214116,
46 0.7071067811865476, 0.6845471059286886, 0.6613118653236518,
47 0.6374239897486896, 0.6129070536529766, 0.5877852522924731,
48 0.5620833778521306, 0.5358267949789965, 0.5090414157503712,
49 0.48175367410171516, 0.4539904997395468, 0.42577929156507266,
50 0.39714789063478056, 0.3681245526846781, 0.3387379202452915,
51 0.30901699437494745, 0.2789911060392295, 0.24868988716485496,
52 0.2181432413965427, 0.18738131458572474, 0.15643446504023092,
53 0.12533323356430426, 0.0941083133185145, 0.06279051952931353,
54 0.031410759078128396, 6.123233995736766e-17, -0.03141075907812828,
55 -0.0627905195293134, -0.09410831331851438, -0.12533323356430437,
56 -0.15643446504023104, -0.18738131458572482, -0.21814324139654234,
57 -0.24868988716485463, -0.27899110603922916, -0.30901699437494734,
58 -0.33873792024529137, -0.368124552684678, -0.39714789063478045,
59 -0.4257792915650727, -0.4539904997395467, -0.48175367410171543,
60 -0.5090414157503713, -0.5358267949789969, -0.5620833778521304,
61 -0.587785252292473, -0.6129070536529763, -0.6374239897486897,
62 -0.6613118653236517, -0.6845471059286887, -0.7071067811865475,
63 -0.7289686274214113, -0.7501110696304596, -0.7705132427757891,
64 -0.7901550123756904, -0.8090169943749473, -0.8270805742745619,
65 -0.8443279255020149, -0.8607420270039435, -0.8763066800438634,
66 -0.8910065241883678, -0.9048270524660194, -0.9177546256839811,
67 -0.9297764858882513, -0.9408807689542255, -0.9510565162951535,
68 -0.9602936856769431, -0.9685831611286311, -0.9759167619387474,
69 -0.9822872507286886, -0.9876883405951377, -0.9921147013144778,
70 -0.99556196460308, -0.9980267284282716, -0.9995065603657316, -1.0
71 };
72
73 float tolerance = 10e-7;
74 for (size_t i = 0; i < inputA.size(); i++) {
75 CHECK (expectedResult[i] ==
76 Approx(arm::app::math::MathUtils::CosineF32(M_PI*inputA[i])).margin(tolerance));
77 }
78}
79
80TEST_CASE("Test SineF32")
81{
82/*Test Constants: */
83std::vector<double> inputA{
84 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
85 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
86 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
87 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
88 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
89 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
90 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
91 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
92 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
93 0.99, 1.0
94};
95std::vector<double> expectedResult{
96
97 0.0, 0.03141075907812829, 0.06279051952931337, 0.09410831331851431,
98 0.12533323356430426, 0.15643446504023087, 0.1873813145857246,
99 0.21814324139654256, 0.2486898871648548, 0.2789911060392293,
100 0.3090169943749474, 0.33873792024529137, 0.3681245526846779,
101 0.3971478906347806, 0.4257792915650727, 0.45399049973954675,
102 0.4817536741017153, 0.5090414157503713, 0.5358267949789967,
103 0.5620833778521306, 0.5877852522924731, 0.6129070536529764,
104 0.6374239897486896, 0.6613118653236518, 0.6845471059286886,
105 0.7071067811865475, 0.7289686274214116, 0.7501110696304596,
106 0.7705132427757893, 0.7901550123756903, 0.8090169943749475,
107 0.8270805742745618, 0.8443279255020151, 0.8607420270039436,
108 0.8763066800438637, 0.8910065241883678, 0.9048270524660196,
109 0.9177546256839811, 0.9297764858882513, 0.9408807689542255,
110 0.9510565162951535, 0.960293685676943, 0.9685831611286311,
111 0.9759167619387473, 0.9822872507286886, 0.9876883405951378,
112 0.9921147013144779, 0.99556196460308, 0.9980267284282716,
113 0.9995065603657316, 1.0, 0.9995065603657316,
114 0.9980267284282716, 0.99556196460308, 0.9921147013144778,
115 0.9876883405951377, 0.9822872507286886, 0.9759167619387474,
116 0.9685831611286312, 0.9602936856769431, 0.9510565162951536,
117 0.9408807689542255, 0.9297764858882513, 0.9177546256839813,
118 0.9048270524660195, 0.8910065241883679, 0.8763066800438635,
119 0.8607420270039436, 0.844327925502015, 0.827080574274562,
120 0.8090169943749475, 0.7901550123756905, 0.7705132427757893,
121 0.7501110696304597, 0.7289686274214114, 0.7071067811865476,
122 0.6845471059286888, 0.6613118653236518, 0.6374239897486899,
123 0.6129070536529764, 0.5877852522924732, 0.5620833778521305,
124 0.535826794978997, 0.5090414157503714, 0.4817536741017156,
125 0.45399049973954686, 0.4257792915650729, 0.3971478906347806,
126 0.36812455268467814, 0.3387379202452913, 0.3090169943749475,
127 0.2789911060392291, 0.24868988716485482, 0.21814324139654231,
128 0.18738131458572502, 0.15643446504023098, 0.12533323356430454,
129 0.09410831331851435, 0.06279051952931358, 0.031410759078128236,
130 1.2246467991473532e-16
131 };
132 float tolerance = 10e-4;
133 for (size_t i = 0; i < inputA.size(); i++) {
134 CHECK (expectedResult[i] ==
135 Approx(arm::app::math::MathUtils::SineF32(M_PI*inputA[i])).margin(tolerance));
136 }
137}
138
139TEST_CASE("Test SqrtF32")
140{
141 /*Test Constants: */
142 std::vector<float> inputA{0,1,2,9,M_PI};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000143 size_t len = inputA.size();
liabar01dee53bc2021-10-29 15:59:04 +0100144 std::vector<float> expectedResult{0, 1, 1.414213562, 3, 1.772453851 };
145
146 for (size_t i=0; i < len; i++){
147 CHECK (expectedResult[i] == Approx(arm::app::math::MathUtils::SqrtF32(inputA[i])));
148 }
149}
150
151TEST_CASE("Test MeanF32")
152{
Richard Burtonc2911442022-04-22 09:08:21 +0100153 /* Test Constants: */
liabar01dee53bc2021-10-29 15:59:04 +0100154 std::vector<float> input
155 {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000};
156
157 /* Manually calculated mean of above vector */
158 float expectedResult = 0.100;
159 CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size())));
Richard Burtonc2911442022-04-22 09:08:21 +0100160
161 /* Mean of 0 */
162 std::vector<float> input2{1, 2, -1, -2};
163 float expectedResult2 = 0.0f;
164 CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size())));
165
166 /* All 0s */
167 std::vector<float> input3 = std::vector<float>(9, 0);
168 float expectedResult3 = 0.0f;
169 CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size())));
170
171 /* All 1s */
172 std::vector<float> input4 = std::vector<float>(9, 1);
173 float expectedResult4 = 1.0f;
174 CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size())));
liabar01dee53bc2021-10-29 15:59:04 +0100175}
176
177TEST_CASE("Test StdDevF32")
178{
179 /*Test Constants: */
180 /* Normally distributed sample data generated by numpy normal library */
181 std::vector<float> input
182 {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387,
183 1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071,
184 -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842,
185 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434,
186 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547,
187 -0.69166075,-0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646,
188 -1.11731035, 0.2344157, 1.65980218, 0.74204416, -0.19183555, -0.88762896,
189 -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514,
190 0.12015895, 0.61720311
191 };
192 uint32_t inputLen = input.size();
193
194 /*Calculate mean using std library to avoid dependency on MathUtils::MeanF32 */
195 float mean = (std::accumulate(input.begin(), input.end(), 0.0f))/float(inputLen);
196 float output = arm::app::math::MathUtils::StdDevF32(input.data(), inputLen, mean);
197
198 /*Manually calculated standard deviation of above vector*/
199 float expectedResult = 0.969589282958136;
200
201 CHECK (expectedResult == Approx(output));
Richard Burtonc2911442022-04-22 09:08:21 +0100202
203 /* All 0s should have 0 std dev. */
204 std::vector<float> input2 = std::vector<float>(4, 0);
205 float expectedResult2 = 0.0f;
206 CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f)));
207
208 /* All 1s should have 0 std dev. */
209 std::vector<float> input3 = std::vector<float>(4, 1);
210 float expectedResult3 = 0.0f;
211 CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f)));
212
213 /* Manually calclualted std value */
214 std::vector<float> input4 {1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
215 float mean2 = (std::accumulate(input4.begin(), input4.end(), 0.0f))/float(input4.size());
216 float expectedResult4 = 2.872281323;
217 CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2)));
liabar01dee53bc2021-10-29 15:59:04 +0100218}
219
220TEST_CASE("Test FFT32")
221{
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100222 constexpr size_t nElem = 512;
223
liabar01dee53bc2021-10-29 15:59:04 +0100224 /*Test Constants: */
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100225 std::vector<float> input_zeros(nElem, 0);
226 std::vector<float> input_ones(nElem, 1);
227
liabar01dee53bc2021-10-29 15:59:04 +0100228 /* Random numbers generated using numpy rand with range [0:1] */
229 std::vector<float> input_random{
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100230 0.42333686, 0.6547418, 0.8933691, 0.91466254, 0.5992143, 0.99474055, 0.97750413,
231 0.97160685, 0.72718734, 0.8699537, 0.643911, 0.09764466, 0.0050113136, 0.46823388,
232 0.13709934, 0.44892532, 0.59728205, 0.04055081, 0.5579888, 0.18445836, 0.66469765,
233 0.82715863, 0.91934484, 0.7844356, 0.23489648, 0.021708783, 0.67819905, 0.75761676,
234 0.48374954, 0.14006922, 0.87082034, 0.7694296, 0.80479276, 0.8241704, 0.95917296,
235 0.5758142, 0.16839339, 0.34290153, 0.5846108, 0.6878044, 0.1067114, 0.5198196,
236 0.4356897, 0.68049103, 0.12480807, 0.3538696, 0.06067087, 0.056964435, 0.5382167,
237 0.07761527, 0.6673144, 0.9045368, 0.11050189, 0.03530183, 0.07864744, 0.98752064,
238 0.014321936, 0.101833574, 0.43293256, 0.87102246, 0.52411795, 0.90232223, 0.49560344,
239 0.6803092, 0.2908511, 0.14653015, 0.99105513, 0.7057098, 0.09623502, 0.039713606,
240 0.88669086, 0.56018597, 0.90632766, 0.99241334, 0.18748309, 0.38991618, 0.6359827,
241 0.05665585, 0.732304, 0.2703365, 0.19014524, 0.5017947, 0.78862536, 0.81253093,
242 0.35050204, 0.2832596, 0.65221876, 0.59856164, 0.42758793, 0.78865635, 0.30943435,
243 0.93780816, 0.62568265, 0.35397422, 0.84209913, 0.48590583, 0.34837773, 0.5811646,
244 0.42924216, 0.26692122, 0.030709852, 0.84459823, 0.09085059, 0.29297647, 0.48539516,
245 0.33488297, 0.7877257, 0.8728821, 0.28454545, 0.7109578, 0.86097074, 0.8536262,
246 0.4978063, 0.5760398, 0.77506036, 0.7716988, 0.27041402, 0.52340513, 0.2055419,
247 0.8728235, 0.13492358, 0.79122984, 0.52998376, 0.33897072, 0.6426309, 0.8766521,
248 0.89287037, 0.74047667, 0.42341164, 0.67437655, 0.4682156, 0.67123246, 0.54287183,
249 0.3580476, 0.94756556, 0.2699457, 0.6131569, 0.75043845, 0.8115012, 0.49610943,
250 0.7108478, 0.90941435, 0.02233071, 0.37346774, 0.33732748, 0.46691266, 0.35784695,
251 0.39391598, 0.8556212, 0.884142, 0.11730601, 0.550112, 0.31513855, 0.69654715,
252 0.58585805, 0.4493127, 0.78515726, 0.8176612, 0.9846698, 0.32842383, 0.41843212,
253 0.48470423, 0.6757128, 0.95876855, 0.5989163, 0.13587572, 0.72886884, 0.88291156,
254 0.34402263, 0.66211045, 0.86188424, 0.21498202, 0.26397392, 0.67372984, 0.91386956,
255 0.7339788, 0.91308993, 0.1953016, 0.1539217, 0.214701, 0.58234113, 0.8019992,
256 0.63969976, 0.041050985, 0.7293308, 0.26341477, 0.54768014, 0.97596467, 0.12385198,
257 0.44149798, 0.5519762, 0.1697347, 0.577215, 0.8213594, 0.47874716, 0.64515114,
258 0.61467725, 0.18463866, 0.23890929, 0.51052976, 0.16807361, 0.53142565, 0.2414274,
259 0.41690814, 0.98815554, 0.6245643, 0.9477003, 0.24780034, 0.82469565, 0.8614785,
260 0.9565832, 0.062440686, 0.9710724, 0.039196696, 0.11030199, 0.35234734, 0.02065066,
261 0.12832293, 0.7328055, 0.48924434, 0.17247158, 0.5769348, 0.44146806, 0.53575355,
262 0.17258933, 0.6980237, 0.86494404, 0.50573164, 0.5033998, 0.71199447, 0.41353586,
263 0.26767612, 0.3789118, 0.046621118, 0.58491063, 0.22861995, 0.03134273, 0.53280216,
264 0.23382367, 0.07748905, 0.96875405, 0.6613716, 0.64087844, 0.8377165, 0.051519375,
265 0.68997836, 0.3776376, 0.43362603, 0.5358754, 0.51419014, 0.12823892, 0.26574057,
266 0.508808, 0.15734084, 0.78327274, 0.5045347, 0.5445746, 0.89297736, 0.8531272,
267 0.91270804, 0.87429863, 0.3965137, 0.13544834, 0.74269205, 0.80592203, 0.045050766,
268 0.13362087, 0.17090783, 0.02873757, 0.99339336, 0.6394376, 0.48203012, 0.70598215,
269 0.37082237, 0.39792424, 0.89938444, 0.312602, 0.48755112, 0.18220617, 0.17303479,
270 0.31954846, 0.78080165, 0.1755106, 0.68262285, 0.84665287, 0.8520143, 0.8459509,
271 0.39417005, 0.30087698, 0.81362164, 0.61927587, 0.32739028, 0.9023775, 0.27578092,
272 0.6830477, 0.15842387, 0.8473049, 0.43057114, 0.2019703, 0.20560141, 0.6237757,
273 0.60283095, 0.27645138, 0.26605442, 0.27985683, 0.41353813, 0.85139906, 0.71711886,
274 0.5444832, 0.73613757, 0.7397004, 0.7406752, 0.41016674, 0.31896713, 0.4541723,
275 0.2795807, 0.47941738, 0.00504193, 0.89091027, 0.8097144, 0.63033766, 0.37252298,
276 0.9132861, 0.5102532, 0.04104481, 0.30368647, 0.21573475, 0.99520445, 0.5047808,
277 0.6868845, 0.99881023, 0.30377692, 0.2554386, 0.47201005, 0.11120686, 0.10077732,
278 0.1853349, 0.49159425, 0.3938629, 0.8989509, 0.9887155, 0.698771, 0.695701,
279 0.78368753, 0.52537227, 0.19451462, 0.3659248, 0.1968508, 0.7751828, 0.33103722,
280 0.40406147, 0.37832898, 0.68663514, 0.32225925, 0.41771907, 0.034218453, 0.42808908,
281 0.20685343, 0.1861495, 0.045986768, 0.8532299, 0.17200677, 0.44670314, 0.56831235,
282 0.5388232, 0.5430553, 0.69175136, 0.6462231, 0.42827028, 0.10050113, 0.30627027,
283 0.9967943, 0.6684778, 0.5928422, 0.63392985, 0.99123496, 0.79301435, 0.7936309,
284 0.42839453, 0.39781123, 0.22329247, 0.0122212395, 0.2807108, 0.19812097, 0.5576105,
285 0.115653396, 0.3732018, 0.7622857, 0.19847734, 0.5310287, 0.7298145, 0.5518292,
286 0.9117333, 0.13215758, 0.33716795, 0.42372775, 0.6779287, 0.35799992, 0.097887225,
287 0.20171605, 0.9948177, 0.1829232, 0.80349857, 0.9807098, 0.22959666, 0.67322475,
288 0.63094735, 0.93454355, 0.15962408, 0.04335433, 0.47104993, 0.36784375, 0.45258796,
289 0.93415564, 0.1655446, 0.7195017, 0.76236975, 0.3846913, 0.01330617, 0.84716374,
290 0.1227003, 0.65102947, 0.6632434, 0.3728453, 0.4222391, 0.6942989, 0.16014872,
291 0.10798196, 0.94033676, 0.026525471, 0.8379024, 0.5484514, 0.13500613, 0.22919805,
292 0.7001831, 0.6573261, 0.38086265, 0.8725666, 0.35077834, 0.28415123, 0.42283052,
293 0.668379, 0.9769895, 0.37621376, 0.646407, 0.11188069, 0.17129017, 0.7441628,
294 0.25617477, 0.7751679, 0.8565412, 0.67631435, 0.45213568, 0.61896557, 0.3387995,
295 0.51607716, 0.60779697, 0.16428445, 0.5080923, 0.13012086, 0.61184275, 0.7690249,
296 0.9578811, 0.67365676, 0.16241212, 0.97157824, 0.5595742, 0.75936574, 0.6043881,
297 0.2149638, 0.4925318, 0.58727825, 0.97953695, 0.01605968, 0.2819307, 0.6448378,
298 0.4265335, 0.661541, 0.3976571, 0.40607136, 0.46425515, 0.2055872, 0.2716193,
299 0.4132582, 0.8372537, 0.37787434, 0.082228854, 0.7985557, 0.9718134, 0.35222608,
300 0.4853643, 0.2569464, 0.14783978, 0.4889042, 0.62900156, 0.19994198, 0.4618481,
301 0.21673755, 0.51749533, 0.1260157, 0.83759904, 0.36438805, 0.6704668, 0.22010763,
302 0.2359318, 0.53004104, 0.9723652, 0.91218954, 0.9153926, 0.48207277, 0.34850466,
303 0.8939421};
liabar01dee53bc2021-10-29 15:59:04 +0100304 std::vector<std::vector<float>> input_vectors{input_zeros,
305 input_ones,
306 input_random};
307
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100308 std::vector<float> output_zeros(nElem, 0);
309 std::vector<float> output_ones(nElem, 0);
310 std::vector<float> output_random(nElem, 0);
liabar01dee53bc2021-10-29 15:59:04 +0100311 std::vector<std::vector<float>> output_vectors{output_zeros,
312 output_ones,
313 output_random};
314
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100315 std::vector<float> expected_result_zeros(nElem, 0);
316 std::vector<float> expected_result_ones(nElem, 0);
317 expected_result_ones[0] = static_cast<float>(nElem);
liabar01dee53bc2021-10-29 15:59:04 +0100318
319 /* Values are stored as [real0, realN/2, real1, im1, real2, im2, ...] */
320 std::vector<float> expected_result_random{
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100321 259.510161, 2.59796867, 2.55982143, -5.91349888, -1.80049237, 1.09902763,
322 4.0094324, 7.76684892, 4.32617219, -4.33636417, 2.98128463, -0.83763449,
323 2.92973078, -3.75655459, 2.27203161, -4.61106145, 5.55562176, -4.71880166,
324 2.13693416, -2.20496619, 2.31174036, -3.52991041, -0.61687068, 2.43455407,
325 4.76317833, 8.66518565, 1.72350562, -1.33641312, 5.82836675, 0.89396187,
326 8.15031483, 4.34599034, 5.99780199, 2.94900065, -0.0234462045, -5.03789597,
327 12.190702, -6.47012928, 1.24434715, 0.23621713, 0.920921279, -9.20510398,
328 -0.773267254, -3.72141078, 6.28883709, 4.00634065, 4.46491682, 0.74625307,
329 -0.587506158, -4.22833058, 3.15189786, -1.82518672, -6.9378226, 2.4170692,
330 3.23045185, 3.33383799, 0.0510059531, -3.4233929, -2.91651323, -0.0258584,
331 5.84499843, -9.51454903, -14.9214047, 5.52200123, 4.5217959, -7.08268703,
332 0.51677542, -2.90878759, 5.04314682, -1.16928599, -10.7329243, -0.2719951,
333 -3.95269565, -2.32475678, -4.11031641, -2.20538835, -0.589005095, -5.65483456,
334 -10.8927018, 5.74801823, -5.72520347, -3.94970165, -0.518407515, 1.23622633,
335 9.56297959, 0.24424306, -3.74306351, 9.63476301, -4.74493837, -0.35443496,
336 6.54760504, 1.16188913, -13.341695, 7.19088609, -2.560458, -0.49557866,
337 2.93460322, 6.91076746, -0.284779221, 6.59958391, 1.70963995, -7.67293252,
338 4.1850079, -0.14627552, -1.24855113, -1.43322867, 0.360644904, -4.11521374,
339 -10.0628421, 2.87531563, -9.34809732, -0.58251846, 3.61799848, -5.10288284,
340 -5.96239076, 1.99792128, -0.0783229243, 1.81741166, 2.32709681, 0.68487206,
341 -3.08398468, 1.5177629, 1.41015388, 4.51146401, 1.90769911, 0.56093423,
342 3.58389141, -0.14974575, -2.20163907, -1.62177814, -1.91904127, 1.94645907,
343 -0.13772293, -4.30291678, -2.61843435, 0.12691241, 3.28959117, 4.7309582,
344 -2.93995652, 0.39835926, 2.89711768, 1.42284586, -5.13129145, -7.26477374,
345 3.74616158, -2.59659457, 3.8574875, -2.93737277, 3.17748694, -4.45041455,
346 -2.68466437, 1.37377726, 1.60008368, 1.63787578, -1.95661401, -6.34937202,
347 -2.62744282, -5.20892662, 0.890553959, -6.37113573, 2.35885332, -10.04547561,
348 0.329866159, 1.89217741, 0.882516491, -3.53298728, -2.22525608, -5.64794388,
349 -5.19226843, 2.5971315, 4.49346648, -0.20428409, -6.14851885, -11.90893875,
350 3.75899776, -1.86910056, 2.78518535, -2.6359501, -2.13423317, 4.86509946,
351 -2.37625499, 7.42404308, 6.71175474, 6.06191618, 2.59014379, 4.76329698,
352 9.19140042, 9.69149015, 3.33307819, -4.03094924, 2.12988453, -0.15820258,
353 -10.3422801, -3.04462388, 3.59852152, -2.00887343, -3.69998656, 0.90050102,
354 0.679959099, -1.88604949, 1.24235316, 0.41309537, 6.13876866, -7.1040085,
355 6.17728674, 1.91667103, -1.32895472, -0.17674661, -6.94720428, 3.10502593,
356 -2.33990738, -1.27840434, 3.2144252, 2.14102714, 2.37498837, 3.8158066,
357 -2.24107675, -5.52527559, -2.9569793, -0.50367608, 3.01687661, 7.08195792,
358 6.7860479, -3.94154162, 2.24402195, 4.60132638, 3.42211139, 4.17689039,
359 -1.17277194, 2.15404472, -2.3748193, 1.42611867, -0.463033506, 3.21563035,
360 1.38662123, 3.98598717, -3.75283402, -2.47600433, -1.97290542, 2.83361487,
361 -0.845662834, 5.57411581, -0.972981483, -11.394208, 1.88220611, -0.80225125,
362 -0.434295854, -8.2954126, 3.81795409, -3.17146, -4.61994107, -1.59820505,
363 -5.98834455, -4.93129451, -0.513862996, -0.15649305, -5.59094391, 6.25244435,
364 -6.59974456, 13.17193115, 4.48609092, 1.64741879, 7.40985006, 0.44896188,
365 3.81058449, -0.76425931, -5.47938416, 4.01447941, -3.21535548, -1.45542238,
366 0.72274083, -0.23983128, -4.32373034, 0.1337671, -5.89365226, -3.18756318,
367 7.90979161, 5.27570134, -3.43094553, -6.00826981, 1.17932561, -3.50027177,
368 0.181306385, 1.1062498, 0.723650536, -1.55500613, -3.88047911, -2.43746762,
369 -6.81565579, 2.16343352, 2.46366137, 2.38704469, -2.55106395, 6.5091449,
370 -2.06510578, 11.11320924, 2.06649835, -1.05026064, 1.63564303, -0.04638729,
371 1.45053876, 0.43730146, 1.25027939, 0.79932743, 2.81088838, 6.95136058,
372 -4.41417255, 2.89610628, 1.15426258, -2.60704937, -2.77744882, 4.12872365,
373 -2.98288336, 6.75607352, 2.36553382, -2.10540332, -7.30042988, -5.44897893,
374 3.44048454, 4.29726231, 2.181995, -0.80126759, -4.04051175, 4.57584864,
375 0.956312116, -4.45183318, 3.42348929, -9.84138181, 8.69604433, -6.6481311,
376 0.468232735, 1.41031176, -1.240857, 10.61672181, 0.356591473, 10.51631421,
377 -0.99743547, 2.72157537, 8.63583929, -2.19404252, -1.53605811, -4.41068581,
378 2.05371873, 1.25665769, 1.65289503, 4.52520582, -0.535062642, 0.82084677,
379 -11.0079476, -5.09361474, 9.63129107, 3.90056638, -4.19779738, 0.06565745,
380 2.42526917, 2.5854233, -3.66709357, 3.80502971, -0.101489353, -6.85423228,
381 -13.9361494, -0.43904617, 6.01800968, -1.30751495, -4.75122234, 2.74740671,
382 5.54971138, 9.43409003, -0.994733058, -3.0096825, -8.60263376, 0.36653762,
383 3.53318614, -2.69194556, -8.9514574, -4.71570923, 5.15417709, 2.68645385,
384 -2.78042293, 8.21739385, 0.590225003, 2.13319153, 1.72158888, 0.18114627,
385 3.92269446, 3.3525857, -3.40313825, 4.39280934, -1.70368966, 1.29121245,
386 -3.11326453, -1.85941318, 4.57078881, 0.72531039, 6.00445664, 4.9588524,
387 -3.32944491, -0.02080722, -7.42374632, -3.23290026, -0.81614579, 3.55935439,
388 -0.619206533, 2.42859073, 2.21486456, 3.76402487, 3.90930695, -3.61610186,
389 -0.812712547, -14.63377988, 1.14460823, -3.14089899, 3.18097435, 1.21957751,
390 2.85181833, 0.89990235, -4.32147361, -5.54219361, -1.12253677, 2.96141081,
391 -4.4257707, 3.17282306, 4.9174671, 1.16977744, -4.55148089, -2.82520179,
392 -1.71684103, 1.91487668, 0.770726836, 0.78534837, -5.91048566, -4.8288477,
393 -1.35560162, -3.60938315, 1.15812301, 2.44299541, 1.3611519, -5.40950935,
394 7.08292127, 0.27720591, -0.160210828, 2.75862348, -1.57403782, 9.97207524,
395 -2.08957576, 8.70299964, -5.33004663, 4.1547783, 3.51580675, -5.10788085,
396 4.37938353, -3.73449894, 1.44673271, 0.51941469, 0.852232446, -1.1134965,
397 -1.43972745, -1.62952127, -2.50759973, 1.19012213, 0.572772282, -2.71833059,
398 -6.8471899, 4.2621535, 1.58954734, -0.53827818, -0.144624396, 7.63866979,
399 0.410423977, -2.4785678, -5.02681867, -2.03469811, 0.959505727, 2.68589705,
400 3.20889444, -10.76452533, -3.84771551, 2.49189796, -3.19895938, -3.49948794,
401 -2.6723897, 5.11386526, -3.85957031, -1.40741978, 0.176663166, -11.7111276,
402 0.639997364, -1.30321198, 3.20767633, 1.65750671, -11.6187257, -4.36634782,
403 -3.18675281, -4.89279155, -4.08760307, 2.19269283, -1.5892487, 0.17948212,
404 4.81376107, 2.01871001, -0.324211095, -0.2790092, 1.12603878, -3.61503491,
405 -2.86982317, -3.03634532, 8.0771391, 2.21302089, 2.91496011, -2.58564072,
406 0.0, 0.0};
liabar01dee53bc2021-10-29 15:59:04 +0100407
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100408 std::vector<std::vector<float>> expected_results_vectors{
409 expected_result_zeros, expected_result_ones, expected_result_random};
liabar01dee53bc2021-10-29 15:59:04 +0100410 arm::app::math::FftInstance fftInstance;
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100411
liabar01dee53bc2021-10-29 15:59:04 +0100412 /* Iterate over each of the input vectors, calculate FFT and compare with corresponding expected_results vectors */
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000413 for (size_t j = 0; j < input_vectors.size(); j++) {
liabar01dee53bc2021-10-29 15:59:04 +0100414 uint16_t fftLen = input_vectors[j].size();
415 arm::app::math::MathUtils::FftInitF32(fftLen, fftInstance);
416 arm::app::math::MathUtils::FftF32(input_vectors[j], output_vectors[j], fftInstance);
417
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100418 const float tolerance = 10e-4;
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000419 for (size_t i = 0; i < fftLen/2; i++) {
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100420 CHECK(output_vectors[j][i] == Approx(expected_results_vectors[j][i]).margin(tolerance));
421 }
422 }
liabar01dee53bc2021-10-29 15:59:04 +0100423
Kshitij Sisodia1e6c6942023-05-05 10:56:37 +0100424 /* Test inverse FFTs using the forward FFT for complex numbers.
425 * IFFT(XVec) = (1/N)(Conj(FFT(Conj(Xvec)))) */
426 for (size_t j = 0; j < input_vectors.size(); j++) {
427 const uint16_t fftLen = input_vectors[j].size();
428 const size_t inputSz = fftLen * 2;
429
430 /* This vector will populate the input for FFT for complex numbers. */
431 std::vector<float> inputWithConjugates(inputSz);
432
433 /* We expect the output of this test to return the original input. */
434 std::vector<float> expectedOutputVector = input_vectors[j];
435
436 /* Placeholder for output vector. */
437 std::vector<float> outputVector(inputWithConjugates.size());
438
439 /* Populate the 0 and N/2 elements (these will be real numbers
440 * only - no imaginary parts. */
441 inputWithConjugates[0] = expected_results_vectors[j][0];
442 inputWithConjugates[fftLen] = expected_results_vectors[j][1];
443
444 /* Populate the rest of the elements - conjugates of the original for the left mirror
445 * and the right side with what the left would have been, i.e.,
446 * conjugate(conjugate(X_left)). */
447 for (size_t i = 2; i < fftLen; i += 2) {
448 inputWithConjugates[i] = expected_results_vectors[j][i];
449 inputWithConjugates[i + 1] = 0 - expected_results_vectors[j][i + 1];
450 inputWithConjugates[fftLen + i] = expected_results_vectors[j][fftLen - i];
451 inputWithConjugates[fftLen + i + 1] = expected_results_vectors[j][fftLen - i + 1];
452 }
453
454 arm::app::math::MathUtils::FftInitF32(
455 fftLen, fftInstance, arm::app::math::FftType::complex);
456 arm::app::math::MathUtils::FftF32(inputWithConjugates, outputVector, fftInstance);
457
458 const float tolerance = 0.1;
459 for (size_t i = 0; i < expectedOutputVector.size(); i++) {
460
461 /* The number returned here will be nElem times the output. */
462 CHECK(outputVector[i * 2] / static_cast<float>(nElem) ==
463 Approx(expectedOutputVector[i]).margin(tolerance));
464
465 /* The imaginary part here should be close to 0 as the original input
466 * we supplied was real. */
467 CHECK(outputVector[i * 2 + 1] / nElem == Approx(0.f).margin(tolerance));
liabar01dee53bc2021-10-29 15:59:04 +0100468 }
469 }
470}
471
472TEST_CASE("Test VecLogarithmF32")
473{
474 /*Test Constants: */
475 std::vector<float> input =
476 { 0.1e-10, 0.5, 1, M_PI, M_E };
477 std::vector<float> expectedResult =
478 {-25.328436, -0.693147181, 0, 1.144729886, 1};
479 std::vector<float> output(input.size());
480
481 arm::app::math::MathUtils::VecLogarithmF32(input,output);
482
483 for (size_t i = 0; i < input.size(); i++)
484 CHECK (expectedResult[i] == Approx(output[i]));
485}
486
487TEST_CASE("Test DotProductF32")
488{
489 /*Test Constants: */
490 std::vector<float> inputA
491 {1,1,1,0,0,0};
492 std::vector<float> inputB
493 {0,0,0,1,1,1};
494 uint32_t len = inputA.size();
495
496 float expectedResult = 0;
497 float dot_prod = arm::app::math::MathUtils::DotProductF32(inputA.data(), inputB.data(), len);
498 CHECK(dot_prod == expectedResult);
499}
500
501TEST_CASE("Test ComplexMagnitudeSquaredF32")
502{
503 /*Test Constants: */
504 std::vector<float> input
505 {0.0, 0.0, 0.5, 0.5,1,1};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000506 size_t inputLen = input.size();
liabar01dee53bc2021-10-29 15:59:04 +0100507
508 std::vector<float> expectedResult
509 {0.0, 0.5, 2,};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000510 size_t outputLen = inputLen/2;
liabar01dee53bc2021-10-29 15:59:04 +0100511 std::vector<float>output(outputLen);
512
513 /* Pass pointers to input/output vectors as this function over-writes the first half
514 * of the input vector with output results */
515 arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen);
516
Kshitij Sisodiab178b282022-01-04 13:37:53 +0000517 for (size_t i = 0; i < outputLen; i++) {
liabar01dee53bc2021-10-29 15:59:04 +0100518 CHECK (expectedResult[i] == Approx(output[i]));
Kshitij Sisodiab178b282022-01-04 13:37:53 +0000519 }
520}
521
522/**
523 * @brief Simple function to test the Softmax function
524 *
525 * @param input Input vector
526 * @param goldenOutput Expected output vector
527 */
528static void TestSoftmaxF32(
529 const std::vector<float>& input,
530 const std::vector<float>& goldenOutput)
531{
532 std::vector<float> output = input; /* Function modifies the vector in-place */
533 arm::app::math::MathUtils::SoftmaxF32(output);
534
535 for (size_t i = 0; i < goldenOutput.size(); ++i) {
536 CHECK(goldenOutput[i] == Approx(output[i]));
537 }
538
539 REQUIRE(output.size() == goldenOutput.size());
540}
541
542TEST_CASE("Test SoftmaxF32")
543{
544 SECTION("Simple series") {
545 const std::vector<float> input {
546 0.0, 1.0, 2.0, 3.0, 4.0,
547 5.0, 6.0, 7.0, 8.0, 9.0
548 };
549
550 const std::vector<float> expectedOutput {
551 7.80134161e-05, 2.12062451e-04,
552 5.76445508e-04, 1.56694135e-03,
553 4.25938820e-03, 1.15782175e-02,
554 3.14728583e-02, 8.55520989e-02,
555 2.32554716e-01, 6.32149258e-01
556 };
557
558 TestSoftmaxF32(input, expectedOutput);
559 }
560
561 SECTION("Random series") {
562 const std::vector<float> input {
563 0.8810943246170809, 0.5877587675947015,
564 0.6841546454788743, 0.4155920960071594,
565 0.9799415323651671, 0.5066432973545711,
566 0.3846024252355448, 0.4568689569632123,
567 0.3284413744557605, 0.49152323726213554
568 };
569
570 const std::vector<float> expectedOutput {
571 0.13329595, 0.09940837,
572 0.10946799, 0.08368583,
573 0.14714509, 0.09166319,
574 0.08113220, 0.08721240,
575 0.07670132, 0.09028766
576 };
577
578 TestSoftmaxF32(input, expectedOutput);
579 }
580
581 SECTION("Series with large STD") {
582 const std::vector<float> input {
583 0.001, 1000.000
584 };
585
586 const std::vector<float> expectedOutput {
587 0.000, 1.000
588 };
589
590 TestSoftmaxF32(input, expectedOutput);
591 }
liabar01dee53bc2021-10-29 15:59:04 +0100592}