blob: ab1153f942347a304da3c285bcc9658f0ca4d955 [file] [log] [blame]
liabar01dee53bc2021-10-29 15:59:04 +01001/*
Isabella Gottardic64f5062022-01-21 15:27:13 +00002 * Copyright (c) 2021 - 2022 Arm Limited. All rights reserved.
liabar01dee53bc2021-10-29 15:59:04 +01003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "PlatformMath.hpp"
18#include <catch.hpp>
19#include <limits>
20#include <numeric>
21
22TEST_CASE("Test CosineF32")
23{
24 /*Test Constants: */
25 std::vector<double> inputA{
26 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
27 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
28 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
29 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
30 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
31 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
32 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
33 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
34 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
35 0.99, 1.0
36 };
37 std::vector<double> expectedResult{
38 1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308,
39 0.9921147013144779, 0.9876883405951378, 0.9822872507286887,
40 0.9759167619387474, 0.9685831611286311, 0.9602936856769431,
41 0.9510565162951535, 0.9408807689542255, 0.9297764858882515,
42 0.9177546256839811, 0.9048270524660195, 0.8910065241883679,
43 0.8763066800438636, 0.8607420270039436, 0.8443279255020151,
44 0.8270805742745618, 0.8090169943749475, 0.7901550123756904,
45 0.7705132427757893, 0.7501110696304596, 0.7289686274214116,
46 0.7071067811865476, 0.6845471059286886, 0.6613118653236518,
47 0.6374239897486896, 0.6129070536529766, 0.5877852522924731,
48 0.5620833778521306, 0.5358267949789965, 0.5090414157503712,
49 0.48175367410171516, 0.4539904997395468, 0.42577929156507266,
50 0.39714789063478056, 0.3681245526846781, 0.3387379202452915,
51 0.30901699437494745, 0.2789911060392295, 0.24868988716485496,
52 0.2181432413965427, 0.18738131458572474, 0.15643446504023092,
53 0.12533323356430426, 0.0941083133185145, 0.06279051952931353,
54 0.031410759078128396, 6.123233995736766e-17, -0.03141075907812828,
55 -0.0627905195293134, -0.09410831331851438, -0.12533323356430437,
56 -0.15643446504023104, -0.18738131458572482, -0.21814324139654234,
57 -0.24868988716485463, -0.27899110603922916, -0.30901699437494734,
58 -0.33873792024529137, -0.368124552684678, -0.39714789063478045,
59 -0.4257792915650727, -0.4539904997395467, -0.48175367410171543,
60 -0.5090414157503713, -0.5358267949789969, -0.5620833778521304,
61 -0.587785252292473, -0.6129070536529763, -0.6374239897486897,
62 -0.6613118653236517, -0.6845471059286887, -0.7071067811865475,
63 -0.7289686274214113, -0.7501110696304596, -0.7705132427757891,
64 -0.7901550123756904, -0.8090169943749473, -0.8270805742745619,
65 -0.8443279255020149, -0.8607420270039435, -0.8763066800438634,
66 -0.8910065241883678, -0.9048270524660194, -0.9177546256839811,
67 -0.9297764858882513, -0.9408807689542255, -0.9510565162951535,
68 -0.9602936856769431, -0.9685831611286311, -0.9759167619387474,
69 -0.9822872507286886, -0.9876883405951377, -0.9921147013144778,
70 -0.99556196460308, -0.9980267284282716, -0.9995065603657316, -1.0
71 };
72
73 float tolerance = 10e-7;
74 for (size_t i = 0; i < inputA.size(); i++) {
75 CHECK (expectedResult[i] ==
76 Approx(arm::app::math::MathUtils::CosineF32(M_PI*inputA[i])).margin(tolerance));
77 }
78}
79
80TEST_CASE("Test SineF32")
81{
82/*Test Constants: */
83std::vector<double> inputA{
84 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
85 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
86 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
87 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
88 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
89 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
90 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
91 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
92 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
93 0.99, 1.0
94};
95std::vector<double> expectedResult{
96
97 0.0, 0.03141075907812829, 0.06279051952931337, 0.09410831331851431,
98 0.12533323356430426, 0.15643446504023087, 0.1873813145857246,
99 0.21814324139654256, 0.2486898871648548, 0.2789911060392293,
100 0.3090169943749474, 0.33873792024529137, 0.3681245526846779,
101 0.3971478906347806, 0.4257792915650727, 0.45399049973954675,
102 0.4817536741017153, 0.5090414157503713, 0.5358267949789967,
103 0.5620833778521306, 0.5877852522924731, 0.6129070536529764,
104 0.6374239897486896, 0.6613118653236518, 0.6845471059286886,
105 0.7071067811865475, 0.7289686274214116, 0.7501110696304596,
106 0.7705132427757893, 0.7901550123756903, 0.8090169943749475,
107 0.8270805742745618, 0.8443279255020151, 0.8607420270039436,
108 0.8763066800438637, 0.8910065241883678, 0.9048270524660196,
109 0.9177546256839811, 0.9297764858882513, 0.9408807689542255,
110 0.9510565162951535, 0.960293685676943, 0.9685831611286311,
111 0.9759167619387473, 0.9822872507286886, 0.9876883405951378,
112 0.9921147013144779, 0.99556196460308, 0.9980267284282716,
113 0.9995065603657316, 1.0, 0.9995065603657316,
114 0.9980267284282716, 0.99556196460308, 0.9921147013144778,
115 0.9876883405951377, 0.9822872507286886, 0.9759167619387474,
116 0.9685831611286312, 0.9602936856769431, 0.9510565162951536,
117 0.9408807689542255, 0.9297764858882513, 0.9177546256839813,
118 0.9048270524660195, 0.8910065241883679, 0.8763066800438635,
119 0.8607420270039436, 0.844327925502015, 0.827080574274562,
120 0.8090169943749475, 0.7901550123756905, 0.7705132427757893,
121 0.7501110696304597, 0.7289686274214114, 0.7071067811865476,
122 0.6845471059286888, 0.6613118653236518, 0.6374239897486899,
123 0.6129070536529764, 0.5877852522924732, 0.5620833778521305,
124 0.535826794978997, 0.5090414157503714, 0.4817536741017156,
125 0.45399049973954686, 0.4257792915650729, 0.3971478906347806,
126 0.36812455268467814, 0.3387379202452913, 0.3090169943749475,
127 0.2789911060392291, 0.24868988716485482, 0.21814324139654231,
128 0.18738131458572502, 0.15643446504023098, 0.12533323356430454,
129 0.09410831331851435, 0.06279051952931358, 0.031410759078128236,
130 1.2246467991473532e-16
131 };
132 float tolerance = 10e-4;
133 for (size_t i = 0; i < inputA.size(); i++) {
134 CHECK (expectedResult[i] ==
135 Approx(arm::app::math::MathUtils::SineF32(M_PI*inputA[i])).margin(tolerance));
136 }
137}
138
139TEST_CASE("Test SqrtF32")
140{
141 /*Test Constants: */
142 std::vector<float> inputA{0,1,2,9,M_PI};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000143 size_t len = inputA.size();
liabar01dee53bc2021-10-29 15:59:04 +0100144 std::vector<float> expectedResult{0, 1, 1.414213562, 3, 1.772453851 };
145
146 for (size_t i=0; i < len; i++){
147 CHECK (expectedResult[i] == Approx(arm::app::math::MathUtils::SqrtF32(inputA[i])));
148 }
149}
150
151TEST_CASE("Test MeanF32")
152{
153 /*Test Constants: */
154 std::vector<float> input
155 {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000};
156
157 /* Manually calculated mean of above vector */
158 float expectedResult = 0.100;
159 CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size())));
160}
161
162TEST_CASE("Test StdDevF32")
163{
164 /*Test Constants: */
165 /* Normally distributed sample data generated by numpy normal library */
166 std::vector<float> input
167 {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387,
168 1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071,
169 -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842,
170 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434,
171 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547,
172 -0.69166075,-0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646,
173 -1.11731035, 0.2344157, 1.65980218, 0.74204416, -0.19183555, -0.88762896,
174 -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514,
175 0.12015895, 0.61720311
176 };
177 uint32_t inputLen = input.size();
178
179 /*Calculate mean using std library to avoid dependency on MathUtils::MeanF32 */
180 float mean = (std::accumulate(input.begin(), input.end(), 0.0f))/float(inputLen);
181 float output = arm::app::math::MathUtils::StdDevF32(input.data(), inputLen, mean);
182
183 /*Manually calculated standard deviation of above vector*/
184 float expectedResult = 0.969589282958136;
185
186 CHECK (expectedResult == Approx(output));
187}
188
189TEST_CASE("Test FFT32")
190{
191 /*Test Constants: */
192 std::vector<float> input_zeros(512, 0);
193 std::vector<float> input_ones(512, 1);
194 /* Random numbers generated using numpy rand with range [0:1] */
195 std::vector<float> input_random{
196 0.42333686,0.6547418,0.8933691,0.91466254,0.5992143,0.99474055,0.97750413,0.97160685,
197 0.72718734,0.8699537,0.643911,0.09764466,0.0050113136,0.46823388,0.13709934,0.44892532,
198 0.59728205,0.04055081,0.5579888,0.18445836,0.66469765,0.82715863,0.91934484,0.7844356,
199 0.23489648,0.021708783,0.67819905,0.75761676,0.48374954,0.14006922,0.87082034,0.7694296,
200 0.80479276,0.8241704,0.95917296,0.5758142,0.16839339,0.34290153,0.5846108,0.6878044,
201 0.1067114,0.5198196,0.4356897,0.68049103,0.12480807,0.3538696,0.06067087,0.056964435,
202 0.5382167,0.07761527,0.6673144,0.9045368,0.11050189,0.03530183,0.07864744,0.98752064,
203 0.014321936,0.101833574,0.43293256,0.87102246,0.52411795,0.90232223,0.49560344,0.6803092,
204 0.2908511,0.14653015,0.99105513,0.7057098,0.09623502,0.039713606,0.88669086,0.56018597,
205 0.90632766,0.99241334,0.18748309,0.38991618,0.6359827,0.05665585,0.732304,0.2703365,
206 0.19014524,0.5017947,0.78862536,0.81253093,0.35050204,0.2832596,0.65221876,0.59856164,
207 0.42758793,0.78865635,0.30943435,0.93780816,0.62568265,0.35397422,0.84209913,0.48590583,
208 0.34837773,0.5811646,0.42924216,0.26692122,0.030709852,0.84459823,0.09085059,0.29297647,
209 0.48539516,0.33488297,0.7877257,0.8728821,0.28454545,0.7109578,0.86097074,0.8536262,
210 0.4978063,0.5760398,0.77506036,0.7716988,0.27041402,0.52340513,0.2055419,0.8728235,
211 0.13492358,0.79122984,0.52998376,0.33897072,0.6426309,0.8766521,0.89287037,0.74047667,
212 0.42341164,0.67437655,0.4682156,0.67123246,0.54287183,0.3580476,0.94756556,0.2699457,
213 0.6131569,0.75043845,0.8115012,0.49610943,0.7108478,0.90941435,0.02233071,0.37346774,
214 0.33732748,0.46691266,0.35784695,0.39391598,0.8556212,0.884142,0.11730601,0.550112,
215 0.31513855,0.69654715,0.58585805,0.4493127,0.78515726,0.8176612,0.9846698,0.32842383,
216 0.41843212,0.48470423,0.6757128,0.95876855,0.5989163,0.13587572,0.72886884,0.88291156,
217 0.34402263,0.66211045,0.86188424,0.21498202,0.26397392,0.67372984,0.91386956,0.7339788,
218 0.91308993,0.1953016,0.1539217,0.214701,0.58234113,0.8019992,0.63969976,0.041050985,
219 0.7293308,0.26341477,0.54768014,0.97596467,0.12385198,0.44149798,0.5519762,0.1697347,
220 0.577215,0.8213594,0.47874716,0.64515114,0.61467725,0.18463866,0.23890929,0.51052976,
221 0.16807361,0.53142565,0.2414274,0.41690814,0.98815554,0.6245643,0.9477003,0.24780034,
222 0.82469565,0.8614785,0.9565832,0.062440686,0.9710724,0.039196696,0.11030199,0.35234734,
223 0.02065066,0.12832293,0.7328055,0.48924434,0.17247158,0.5769348,0.44146806,0.53575355,
224 0.17258933,0.6980237,0.86494404,0.50573164,0.5033998,0.71199447,0.41353586,0.26767612,
225 0.3789118,0.046621118,0.58491063,0.22861995,0.03134273,0.53280216,0.23382367,0.07748905,
226 0.96875405,0.6613716,0.64087844,0.8377165,0.051519375,0.68997836,0.3776376,0.43362603,
227 0.5358754,0.51419014,0.12823892,0.26574057,0.508808,0.15734084,0.78327274,0.5045347,
228 0.5445746,0.89297736,0.8531272,0.91270804,0.87429863,0.3965137,0.13544834,0.74269205,
229 0.80592203,0.045050766,0.13362087,0.17090783,0.02873757,0.99339336,0.6394376,0.48203012,
230 0.70598215,0.37082237,0.39792424,0.89938444,0.312602,0.48755112,0.18220617,0.17303479,
231 0.31954846,0.78080165,0.1755106,0.68262285,0.84665287,0.8520143,0.8459509,0.39417005,
232 0.30087698,0.81362164,0.61927587,0.32739028,0.9023775,0.27578092,0.6830477,0.15842387,
233 0.8473049,0.43057114,0.2019703,0.20560141,0.6237757,0.60283095,0.27645138,0.26605442,
234 0.27985683,0.41353813,0.85139906,0.71711886,0.5444832,0.73613757,0.7397004,0.7406752,
235 0.41016674,0.31896713,0.4541723,0.2795807,0.47941738,0.00504193,0.89091027,0.8097144,
236 0.63033766,0.37252298,0.9132861,0.5102532,0.04104481,0.30368647,0.21573475,0.99520445,
237 0.5047808,0.6868845,0.99881023,0.30377692,0.2554386,0.47201005,0.11120686,0.10077732,
238 0.1853349,0.49159425,0.3938629,0.8989509,0.9887155,0.698771,0.695701,0.78368753,0.52537227,
239 0.19451462,0.3659248,0.1968508,0.7751828,0.33103722,0.40406147,0.37832898,0.68663514,
240 0.32225925,0.41771907,0.034218453,0.42808908,0.20685343,0.1861495,0.045986768,0.8532299,
241 0.17200677,0.44670314,0.56831235,0.5388232,0.5430553,0.69175136,0.6462231,0.42827028,
242 0.10050113,0.30627027,0.9967943,0.6684778,0.5928422,0.63392985,0.99123496,0.79301435,
243 0.7936309,0.42839453,0.39781123,0.22329247,0.0122212395,0.2807108,0.19812097,0.5576105,
244 0.115653396,0.3732018,0.7622857,0.19847734,0.5310287,0.7298145,0.5518292,0.9117333,
245 0.13215758,0.33716795,0.42372775,0.6779287,0.35799992,0.097887225,0.20171605,0.9948177,
246 0.1829232,0.80349857,0.9807098,0.22959666,0.67322475,0.63094735,0.93454355,0.15962408,
247 0.04335433,0.47104993,0.36784375,0.45258796,0.93415564,0.1655446,0.7195017,0.76236975,
248 0.3846913,0.01330617,0.84716374,0.1227003,0.65102947,0.6632434,0.3728453,0.4222391,
249 0.6942989,0.16014872,0.10798196,0.94033676,0.026525471,0.8379024,0.5484514,0.13500613,
250 0.22919805,0.7001831,0.6573261,0.38086265,0.8725666,0.35077834,0.28415123,0.42283052,
251 0.668379,0.9769895,0.37621376,0.646407,0.11188069,0.17129017,0.7441628,0.25617477,
252 0.7751679,0.8565412,0.67631435,0.45213568,0.61896557,0.3387995,0.51607716,0.60779697,
253 0.16428445,0.5080923,0.13012086,0.61184275,0.7690249,0.9578811,0.67365676,0.16241212,
254 0.97157824,0.5595742,0.75936574,0.6043881,0.2149638,0.4925318,0.58727825,0.97953695,
255 0.01605968,0.2819307,0.6448378,0.4265335,0.661541,0.3976571,0.40607136,0.46425515,
256 0.2055872,0.2716193,0.4132582,0.8372537,0.37787434,0.082228854,0.7985557,0.9718134,
257 0.35222608,0.4853643,0.2569464,0.14783978,0.4889042,0.62900156,0.19994198,0.4618481,
258 0.21673755,0.51749533,0.1260157,0.83759904,0.36438805,0.6704668,0.22010763,0.2359318,
259 0.53004104,0.9723652,0.91218954,0.9153926,0.48207277,0.34850466,0.8939421
260 };
261 std::vector<std::vector<float>> input_vectors{input_zeros,
262 input_ones,
263 input_random};
264
265 std::vector<float> output_zeros(512, 0);
266 std::vector<float> output_ones(512, 0);
267 std::vector<float> output_random(512, 0);
268 std::vector<std::vector<float>> output_vectors{output_zeros,
269 output_ones,
270 output_random};
271
272 std::vector<float> expected_result_zeros(512, 0);
273 std::vector<float> expected_result_ones(512, 0);
274 expected_result_ones[0] = 512.0;
275
276 /* Values are stored as [real0, realN/2, real1, im1, real2, im2, ...] */
277 std::vector<float> expected_result_random{
278 2.59510161e+02, 2.59796867, 2.55982143, -5.91349888,
279 -1.80049237, 1.09902763, 4.00943240, 7.76684892,
280 4.32617219, -4.33636417, 2.98128463, -0.83763449,
281 2.92973078, -3.75655459, 2.27203161, -4.61106145,
282 5.55562176, -4.71880166, 2.13693416, -2.20496619,
283 2.31174036, -3.52991041, -6.16870680e-01, 2.43455407,
284 4.76317833, 8.66518565, 1.72350562, -1.33641312,
285 5.82836675, 0.89396187, 8.15031483, 4.34599034,
286 5.99780199, 2.94900065, -2.34462045e-02, -5.03789597,
287 1.21907020e+01, -6.47012928, 1.24434715, 0.23621713,
288 9.20921279e-01, -9.20510398, -7.73267254e-01, -3.72141078,
289 6.28883709, 4.00634065, 4.46491682, 0.74625307,
290 -5.87506158e-01, -4.22833058, 3.15189786, -1.82518672,
291 -6.93782260, 2.4170692, 3.23045185, 3.33383799,
292 5.10059531e-02, -3.4233929, -2.91651323, -0.0258584,
293 5.84499843, -9.51454903, -1.49214047e+01, 5.52200123,
294 4.52179590, -7.08268703, 5.16775420e-01, -2.90878759,
295 5.04314682, -1.16928599, -1.07329243e+01, -0.2719951,
296 -3.95269565, -2.32475678, -4.11031641, -2.20538835,
297 -5.89005095e-01, -5.65483456, -1.08927018e+01, 5.74801823,
298 -5.72520347, -3.94970165, -5.18407515e-01, 1.23622633,
299 9.56297959, 0.24424306, -3.74306351, 9.63476301,
300 -4.74493837, -0.35443496, 6.54760504, 1.16188913,
301 -1.33416950e+01, 7.19088609, -2.56045800, -0.49557866,
302 2.93460322, 6.91076746, -2.84779221e-01, 6.59958391,
303 1.70963995, -7.67293252, 4.18500790, -0.14627552,
304 -1.24855113, -1.43322867, 3.60644904e-01, -4.11521374,
305 -1.00628421e+01, 2.87531563, -9.34809732, -0.58251846,
306 3.61799848, -5.10288284, -5.96239076, 1.99792128,
307 -7.83229243e-02, 1.81741166, 2.32709681, 0.68487206,
308 -3.08398468, 1.5177629 , 1.41015388, 4.51146401,
309 1.90769911, 0.56093423, 3.58389141, -0.14974575,
310 -2.20163907, -1.62177814, -1.91904127, 1.94645907,
311 -1.37722930e-01, -4.30291678, -2.61843435, 0.12691241,
312 3.28959117, 4.7309582 , -2.93995652, 0.39835926,
313 2.89711768, 1.42284586, -5.13129145, -7.26477374,
314 3.74616158, -2.59659457, 3.85748750, -2.93737277,
315 3.17748694, -4.45041455, -2.68466437, 1.37377726,
316 1.60008368, 1.63787578, -1.95661401, -6.34937202,
317 -2.62744282, -5.20892662, 8.90553959e-01, -6.37113573,
318 2.35885332,-10.04547561, 3.29866159e-01, 1.89217741,
319 8.82516491e-01, -3.53298728, -2.22525608, -5.64794388,
320 -5.19226843, 2.5971315 , 4.49346648, -0.20428409,
321 -6.14851885,-11.90893875, 3.75899776, -1.86910056,
322 2.78518535, -2.6359501 , -2.13423317, 4.86509946,
323 -2.37625499, 7.42404308, 6.71175474, 6.06191618,
324 2.59014379, 4.76329698, 9.19140042, 9.69149015,
325 3.33307819, -4.03094924, 2.12988453, -0.15820258,
326 -1.03422801e+01, -3.04462388, 3.59852152, -2.00887343,
327 -3.69998656, 0.90050102, 6.79959099e-01, -1.88604949,
328 1.24235316, 0.41309537, 6.13876866, -7.1040085 ,
329 6.17728674, 1.91667103, -1.32895472, -0.17674661,
330 -6.94720428, 3.10502593, -2.33990738, -1.27840434,
331 3.21442520, 2.14102714, 2.37498837, 3.8158066 ,
332 -2.24107675, -5.52527559, -2.95697930, -0.50367608,
333 3.01687661, 7.08195792, 6.78604790, -3.94154162,
334 2.24402195, 4.60132638, 3.42211139, 4.17689039,
335 -1.17277194, 2.15404472, -2.37481930, 1.42611867,
336 -4.63033506e-01, 3.21563035, 1.38662123, 3.98598717,
337 -3.75283402, -2.47600433, -1.97290542, 2.83361487,
338 -8.45662834e-01, 5.57411581, -9.72981483e-01,-11.394208,
339 1.88220611, -0.80225125, -4.34295854e-01, -8.2954126,
340 3.81795409, -3.17146 , -4.61994107, -1.59820505,
341 -5.98834455, -4.93129451, -5.13862996e-01, -0.15649305,
342 -5.59094391, 6.25244435, -6.59974456,13.17193115,
343 4.48609092, 1.64741879, 7.40985006, 0.44896188,
344 3.81058449, -0.76425931, -5.47938416, 4.01447941,
345 -3.21535548, -1.45542238, 7.22740830e-01, -0.23983128,
346 -4.32373034, 0.1337671 , -5.89365226, -3.18756318,
347 7.90979161, 5.27570134, -3.43094553, -6.00826981,
348 1.17932561, -3.50027177, 1.81306385e-01, 1.1062498 ,
349 7.23650536e-01, -1.55500613, -3.88047911, -2.43746762,
350 -6.81565579, 2.16343352, 2.46366137, 2.38704469,
351 -2.55106395, 6.5091449 , -2.06510578,11.11320924,
352 2.06649835, -1.05026064, 1.63564303, -0.04638729,
353 1.45053876, 0.43730146, 1.25027939, 0.79932743,
354 2.81088838, 6.95136058, -4.41417255, 2.89610628,
355 1.15426258, -2.60704937, -2.77744882, 4.12872365,
356 -2.98288336, 6.75607352, 2.36553382, -2.10540332,
357 -7.30042988, -5.44897893, 3.44048454, 4.29726231,
358 2.18199500, -0.80126759, -4.04051175, 4.57584864,
359 9.56312116e-01, -4.45183318, 3.42348929, -9.84138181,
360 8.69604433, -6.6481311 , 4.68232735e-01, 1.41031176,
361 -1.24085700,10.61672181, 3.56591473e-01,10.51631421,
362 -9.97435470e-01, 2.72157537, 8.63583929, -2.19404252,
363 -1.53605811, -4.41068581, 2.05371873, 1.25665769,
364 1.65289503, 4.52520582, -5.35062642e-01, 0.82084677,
365 -1.10079476e+01, -5.09361474, 9.63129107, 3.90056638,
366 -4.19779738, 0.06565745, 2.42526917, 2.5854233 ,
367 -3.66709357, 3.80502971, -1.01489353e-01, -6.85423228,
368 -1.39361494e+01, -0.43904617, 6.01800968, -1.30751495,
369 -4.75122234, 2.74740671, 5.54971138, 9.43409003,
370 -9.94733058e-01, -3.0096825 , -8.60263376, 0.36653762,
371 3.53318614, -2.69194556, -8.95145740, -4.71570923,
372 5.15417709, 2.68645385, -2.78042293, 8.21739385,
373 5.90225003e-01, 2.13319153, 1.72158888, 0.18114627,
374 3.92269446, 3.3525857 , -3.40313825, 4.39280934,
375 -1.70368966, 1.29121245, -3.11326453, -1.85941318,
376 4.57078881, 0.72531039, 6.00445664, 4.9588524 ,
377 -3.32944491, -0.02080722, -7.42374632, -3.23290026,
378 -8.16145790e-01, 3.55935439, -6.19206533e-01, 2.42859073,
379 2.21486456, 3.76402487, 3.90930695, -3.61610186,
380 -8.12712547e-01,-14.63377988, 1.14460823, -3.14089899,
381 3.18097435, 1.21957751, 2.85181833, 0.89990235,
382 -4.32147361, -5.54219361, -1.12253677, 2.96141081,
383 -4.42577070, 3.17282306, 4.91746710, 1.16977744,
384 -4.55148089, -2.82520179, -1.71684103, 1.91487668,
385 7.70726836e-01, 0.78534837, -5.91048566, -4.8288477 ,
386 -1.35560162, -3.60938315, 1.15812301, 2.44299541,
387 1.36115190, -5.40950935, 7.08292127, 0.27720591,
388 -1.60210828e-01, 2.75862348, -1.57403782, 9.97207524,
389 -2.08957576, 8.70299964, -5.33004663, 4.1547783 ,
390 3.51580675, -5.10788085, 4.37938353, -3.73449894,
391 1.44673271, 0.51941469, 8.52232446e-01, -1.1134965 ,
392 -1.43972745, -1.62952127, -2.50759973, 1.19012213,
393 5.72772282e-01, -2.71833059, -6.84718990, 4.2621535 ,
394 1.58954734, -0.53827818, -1.44624396e-01, 7.63866979,
395 4.10423977e-01, -2.4785678 , -5.02681867, -2.03469811,
396 9.59505727e-01, 2.68589705, 3.20889444,-10.76452533,
397 -3.84771551, 2.49189796, -3.19895938, -3.49948794,
398 -2.67238970, 5.11386526, -3.85957031, -1.40741978,
399 1.76663166e-01,-11.7111276 , 6.39997364e-01, -1.30321198,
400 3.20767633, 1.65750671, -1.16187257e+01, -4.36634782,
401 -3.18675281, -4.89279155, -4.08760307, 2.19269283,
402 -1.58924870, 0.17948212, 4.81376107, 2.01871001,
403 -3.24211095e-01, -0.2790092 , 1.12603878, -3.61503491,
404 -2.86982317, -3.03634532, 8.07713910, 2.21302089,
405 2.91496011, -2.58564072, -2.45868054, -4.69447438,
406 -2.45868054, 4.69447438,
407 2.91496011, 2.58564072, 8.07713910, -2.21302089,
408 -2.86982317, 3.03634532, 1.12603878, 3.61503491,
409 -3.24211095e-01, 0.2790092 , 4.81376107, -2.01871001,
410 -1.58924870, -0.17948212, -4.08760307, -2.19269283,
411 -3.18675281, 4.89279155, -1.16187257e+01, 4.36634782,
412 3.20767633, -1.65750671, 6.39997364e-01, 1.30321198,
413 1.76663166e-01,11.7111276 , -3.85957031, 1.40741978,
414 -2.67238970, -5.11386526, -3.19895938, 3.49948794,
415 -3.84771551, -2.49189796, 3.20889444,10.76452533,
416 9.59505727e-01, -2.68589705, -5.02681867, 2.03469811,
417 4.10423977e-01, 2.4785678 , -1.44624396e-01, -7.63866979,
418 1.58954734, 0.53827818, -6.84718990, -4.2621535 ,
419 5.72772282e-01, 2.71833059, -2.50759973, -1.19012213,
420 -1.43972745, 1.62952127, 8.52232446e-01, 1.1134965 ,
421 1.44673271, -0.51941469, 4.37938353, 3.73449894,
422 3.51580675, 5.10788085, -5.33004663, -4.1547783 ,
423 -2.08957576, -8.70299964, -1.57403782, -9.97207524,
424 -1.60210828e-01, -2.75862348, 7.08292127, -0.27720591,
425 1.36115190, 5.40950935, 1.15812301, -2.44299541,
426 -1.35560162, 3.60938315, -5.91048566, 4.8288477 ,
427 7.70726836e-01, -0.78534837, -1.71684103, -1.91487668,
428 -4.55148089, 2.82520179, 4.91746710, -1.16977744,
429 -4.42577070, -3.17282306, -1.12253677, -2.96141081,
430 -4.32147361, 5.54219361, 2.85181833, -0.89990235,
431 3.18097435, -1.21957751, 1.14460823, 3.14089899,
432 -8.12712547e-01,14.63377988, 3.90930695, 3.61610186,
433 2.21486456, -3.76402487, -6.19206533e-01, -2.42859073,
434 -8.16145790e-01, -3.55935439, -7.42374632, 3.23290026,
435 -3.32944491, 0.02080722, 6.00445664, -4.9588524 ,
436 4.57078881, -0.72531039, -3.11326453, 1.85941318,
437 -1.70368966, -1.29121245, -3.40313825, -4.39280934,
438 3.92269446, -3.3525857 , 1.72158888, -0.18114627,
439 5.90225003e-01, -2.13319153, -2.78042293, -8.21739385,
440 5.15417709, -2.68645385, -8.95145740, 4.71570923,
441 3.53318614, 2.69194556, -8.60263376, -0.36653762,
442 -9.94733058e-01, 3.0096825 , 5.54971138, -9.43409003,
443 -4.75122234, -2.74740671, 6.01800968, 1.30751495,
444 -1.39361494e+01, 0.43904617, -1.01489353e-01, 6.85423228,
445 -3.66709357, -3.80502971, 2.42526917, -2.5854233 ,
446 -4.19779738, -0.06565745, 9.63129107, -3.90056638,
447 -1.10079476e+01, 5.09361474, -5.35062642e-01, -0.82084677,
448 1.65289503, -4.52520582, 2.05371873, -1.25665769,
449 -1.53605811, 4.41068581, 8.63583929, 2.19404252,
450 -9.97435470e-01, -2.72157537, 3.56591473e-01,-10.51631421,
451 -1.24085700,-10.61672181, 4.68232735e-01, -1.41031176,
452 8.69604433, 6.6481311 , 3.42348929, 9.84138181,
453 9.56312116e-01, 4.45183318, -4.04051175, -4.57584864,
454 2.18199500, 0.80126759, 3.44048454, -4.29726231,
455 -7.30042988, 5.44897893, 2.36553382, 2.10540332,
456 -2.98288336, -6.75607352, -2.77744882, -4.12872365,
457 1.15426258, 2.60704937, -4.41417255, -2.89610628,
458 2.81088838, -6.95136058, 1.25027939, -0.79932743,
459 1.45053876, -0.43730146, 1.63564303, 0.04638729,
460 2.06649835, 1.05026064, -2.06510578,-11.11320924,
461 -2.55106395, -6.5091449 , 2.46366137, -2.38704469,
462 -6.81565579, -2.16343352, -3.88047911, 2.43746762,
463 7.23650536e-01, 1.55500613, 1.81306385e-01, -1.1062498 ,
464 1.17932561, 3.50027177, -3.43094553, 6.00826981,
465 7.90979161, -5.27570134, -5.89365226, 3.18756318,
466 -4.32373034, -0.1337671 , 7.22740830e-01, 0.23983128,
467 -3.21535548, 1.45542238, -5.47938416, -4.01447941,
468 3.81058449, 0.76425931, 7.40985006, -0.44896188,
469 4.48609092, -1.64741879, -6.59974456,-13.17193115,
470 -5.59094391, -6.25244435, -5.13862996e-01, 0.15649305,
471 -5.98834455, 4.93129451, -4.61994107, 1.59820505,
472 3.81795409, 3.17146 , -4.34295854e-01, 8.2954126 ,
473 1.88220611, 0.80225125, -9.72981483e-01,11.394208 ,
474 -8.45662834e-01, -5.57411581, -1.97290542, -2.83361487,
475 -3.75283402, 2.47600433, 1.38662123, -3.98598717,
476 -4.63033506e-01, -3.21563035, -2.37481930, -1.42611867,
477 -1.17277194, -2.15404472, 3.42211139, -4.17689039,
478 2.24402195, -4.60132638, 6.78604790, 3.94154162,
479 3.01687661, -7.08195792, -2.95697930, 0.50367608,
480 -2.24107675, 5.52527559, 2.37498837, -3.8158066 ,
481 3.21442520, -2.14102714, -2.33990738, 1.27840434,
482 -6.94720428, -3.10502593, -1.32895472, 0.17674661,
483 6.17728674, -1.91667103, 6.13876866, 7.1040085 ,
484 1.24235316, -0.41309537, 6.79959099e-01, 1.88604949,
485 -3.69998656, -0.90050102, 3.59852152, 2.00887343,
486 -1.03422801e+01, 3.04462388, 2.12988453, 0.15820258,
487 3.33307819, 4.03094924, 9.19140042, -9.69149015,
488 2.59014379, -4.76329698, 6.71175474, -6.06191618,
489 -2.37625499, -7.42404308, -2.13423317, -4.86509946,
490 2.78518535, 2.6359501 , 3.75899776, 1.86910056,
491 -6.14851885,11.90893875, 4.49346648, 0.20428409,
492 -5.19226843, -2.5971315 , -2.22525608, 5.64794388,
493 8.82516491e-01, 3.53298728, 3.29866159e-01, -1.89217741,
494 2.35885332,10.04547561, 8.90553959e-01, 6.37113573,
495 -2.62744282, 5.20892662, -1.95661401, 6.34937202,
496 1.60008368, -1.63787578, -2.68466437, -1.37377726,
497 3.17748694, 4.45041455, 3.85748750, 2.93737277,
498 3.74616158, 2.59659457, -5.13129145, 7.26477374,
499 2.89711768, -1.42284586, -2.93995652, -0.39835926,
500 3.28959117, -4.7309582 , -2.61843435, -0.12691241,
501 -1.37722930e-01, 4.30291678, -1.91904127, -1.94645907,
502 -2.20163907, 1.62177814, 3.58389141, 0.14974575,
503 1.90769911, -0.56093423, 1.41015388, -4.51146401,
504 -3.08398468, -1.5177629 , 2.32709681, -0.68487206,
505 -7.83229243e-02, -1.81741166, -5.96239076, -1.99792128,
506 3.61799848, 5.10288284, -9.34809732, 0.58251846,
507 -1.00628421e+01, -2.87531563, 3.60644904e-01, 4.11521374,
508 -1.24855113, 1.43322867, 4.18500790, 0.14627552,
509 1.70963995, 7.67293252, -2.84779221e-01, -6.59958391,
510 2.93460322, -6.91076746, -2.56045800, 0.49557866,
511 -1.33416950e+01, -7.19088609, 6.54760504, -1.16188913,
512 -4.74493837, 0.35443496, -3.74306351, -9.63476301,
513 9.56297959, -0.24424306, -5.18407515e-01, -1.23622633,
514 -5.72520347, 3.94970165, -1.08927018e+01, -5.74801823,
515 -5.89005095e-01, 5.65483456, -4.11031641, 2.20538835,
516 -3.95269565, 2.32475678, -1.07329243e+01, 0.2719951 ,
517 5.04314682, 1.16928599, 5.16775420e-01, 2.90878759,
518 4.52179590, 7.08268703, -1.49214047e+01, -5.52200123,
519 5.84499843, 9.51454903, -2.91651323, 0.0258584 ,
520 5.10059531e-02, 3.4233929 , 3.23045185, -3.33383799,
521 -6.93782260, -2.4170692 , 3.15189786, 1.82518672,
522 -5.87506158e-01, 4.22833058, 4.46491682, -0.74625307,
523 6.28883709, -4.00634065, -7.73267254e-01, 3.72141078,
524 9.20921279e-01, 9.20510398, 1.24434715, -0.23621713,
525 1.21907020e+01, 6.47012928, -2.34462045e-02, 5.03789597,
526 5.99780199, -2.94900065, 8.15031483, -4.34599034,
527 5.82836675, -0.89396187, 1.72350562, 1.33641312,
528 4.76317833, -8.66518565, -6.16870680e-01, -2.43455407,
529 2.31174036, 3.52991041, 2.13693416, 2.20496619,
530 5.55562176, 4.71880166, 2.27203161, 4.61106145,
531 2.92973078, 3.75655459, 2.98128463, 0.83763449,
532 4.32617219, 4.33636417, 4.00943240, -7.76684892,
533 -1.80049237, -1.09902763, 2.55982143, 5.91349888
534 };
535
536 std::vector<std::vector<float>> expected_results_vectors{expected_result_zeros,
537 expected_result_ones,
538 expected_result_random};
539 arm::app::math::FftInstance fftInstance;
540 /* Iterate over each of the input vectors, calculate FFT and compare with corresponding expected_results vectors */
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000541 for (size_t j = 0; j < input_vectors.size(); j++) {
liabar01dee53bc2021-10-29 15:59:04 +0100542 uint16_t fftLen = input_vectors[j].size();
543 arm::app::math::MathUtils::FftInitF32(fftLen, fftInstance);
544 arm::app::math::MathUtils::FftF32(input_vectors[j], output_vectors[j], fftInstance);
545
546 float tolerance = 10e-4;
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000547 for (size_t i = 0; i < fftLen/2; i++) {
liabar01dee53bc2021-10-29 15:59:04 +0100548 CHECK (output_vectors[j][i] ==
549 Approx(expected_results_vectors[j][i]).margin(tolerance));
550
551 }
552 }
553}
554
555TEST_CASE("Test VecLogarithmF32")
556{
557 /*Test Constants: */
558 std::vector<float> input =
559 { 0.1e-10, 0.5, 1, M_PI, M_E };
560 std::vector<float> expectedResult =
561 {-25.328436, -0.693147181, 0, 1.144729886, 1};
562 std::vector<float> output(input.size());
563
564 arm::app::math::MathUtils::VecLogarithmF32(input,output);
565
566 for (size_t i = 0; i < input.size(); i++)
567 CHECK (expectedResult[i] == Approx(output[i]));
568}
569
570TEST_CASE("Test DotProductF32")
571{
572 /*Test Constants: */
573 std::vector<float> inputA
574 {1,1,1,0,0,0};
575 std::vector<float> inputB
576 {0,0,0,1,1,1};
577 uint32_t len = inputA.size();
578
579 float expectedResult = 0;
580 float dot_prod = arm::app::math::MathUtils::DotProductF32(inputA.data(), inputB.data(), len);
581 CHECK(dot_prod == expectedResult);
582}
583
584TEST_CASE("Test ComplexMagnitudeSquaredF32")
585{
586 /*Test Constants: */
587 std::vector<float> input
588 {0.0, 0.0, 0.5, 0.5,1,1};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000589 size_t inputLen = input.size();
liabar01dee53bc2021-10-29 15:59:04 +0100590
591 std::vector<float> expectedResult
592 {0.0, 0.5, 2,};
Kshitij Sisodiab59ba682021-11-23 17:19:52 +0000593 size_t outputLen = inputLen/2;
liabar01dee53bc2021-10-29 15:59:04 +0100594 std::vector<float>output(outputLen);
595
596 /* Pass pointers to input/output vectors as this function over-writes the first half
597 * of the input vector with output results */
598 arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen);
599
Kshitij Sisodiab178b282022-01-04 13:37:53 +0000600 for (size_t i = 0; i < outputLen; i++) {
liabar01dee53bc2021-10-29 15:59:04 +0100601 CHECK (expectedResult[i] == Approx(output[i]));
Kshitij Sisodiab178b282022-01-04 13:37:53 +0000602 }
603}
604
605/**
606 * @brief Simple function to test the Softmax function
607 *
608 * @param input Input vector
609 * @param goldenOutput Expected output vector
610 */
611static void TestSoftmaxF32(
612 const std::vector<float>& input,
613 const std::vector<float>& goldenOutput)
614{
615 std::vector<float> output = input; /* Function modifies the vector in-place */
616 arm::app::math::MathUtils::SoftmaxF32(output);
617
618 for (size_t i = 0; i < goldenOutput.size(); ++i) {
619 CHECK(goldenOutput[i] == Approx(output[i]));
620 }
621
622 REQUIRE(output.size() == goldenOutput.size());
623}
624
625TEST_CASE("Test SoftmaxF32")
626{
627 SECTION("Simple series") {
628 const std::vector<float> input {
629 0.0, 1.0, 2.0, 3.0, 4.0,
630 5.0, 6.0, 7.0, 8.0, 9.0
631 };
632
633 const std::vector<float> expectedOutput {
634 7.80134161e-05, 2.12062451e-04,
635 5.76445508e-04, 1.56694135e-03,
636 4.25938820e-03, 1.15782175e-02,
637 3.14728583e-02, 8.55520989e-02,
638 2.32554716e-01, 6.32149258e-01
639 };
640
641 TestSoftmaxF32(input, expectedOutput);
642 }
643
644 SECTION("Random series") {
645 const std::vector<float> input {
646 0.8810943246170809, 0.5877587675947015,
647 0.6841546454788743, 0.4155920960071594,
648 0.9799415323651671, 0.5066432973545711,
649 0.3846024252355448, 0.4568689569632123,
650 0.3284413744557605, 0.49152323726213554
651 };
652
653 const std::vector<float> expectedOutput {
654 0.13329595, 0.09940837,
655 0.10946799, 0.08368583,
656 0.14714509, 0.09166319,
657 0.08113220, 0.08721240,
658 0.07670132, 0.09028766
659 };
660
661 TestSoftmaxF32(input, expectedOutput);
662 }
663
664 SECTION("Series with large STD") {
665 const std::vector<float> input {
666 0.001, 1000.000
667 };
668
669 const std::vector<float> expectedOutput {
670 0.000, 1.000
671 };
672
673 TestSoftmaxF32(input, expectedOutput);
674 }
liabar01dee53bc2021-10-29 15:59:04 +0100675}