blob: 8b73e4225fe846dbda9bc5c0c3d2ef246c1f9a89 [file] [log] [blame]
/*
* SPDX-FileCopyrightText: Copyright 2021 - 2023 Arm Limited and/or its affiliates
* <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "PlatformMath.hpp"
#include <catch.hpp>
#include <limits>
#include <numeric>
/*
 * Checks MathUtils::CosineF32 against double-precision reference values of
 * cos(pi * x) for x = 0.00, 0.01, ..., 1.00, i.e. angles spanning [0, pi].
 */
TEST_CASE("Test CosineF32")
{
    /* Test constants: fractions of pi at which the cosine is evaluated. */
    const std::vector<double> inputA{
        0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
        0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
        0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
        0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
        0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
        0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
        0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
        0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
        0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
        0.99, 1.0
    };
    /* Reference values: expectedResult[i] == cos(pi * inputA[i]). */
    const std::vector<double> expectedResult{
        1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308,
        0.9921147013144779, 0.9876883405951378, 0.9822872507286887,
        0.9759167619387474, 0.9685831611286311, 0.9602936856769431,
        0.9510565162951535, 0.9408807689542255, 0.9297764858882515,
        0.9177546256839811, 0.9048270524660195, 0.8910065241883679,
        0.8763066800438636, 0.8607420270039436, 0.8443279255020151,
        0.8270805742745618, 0.8090169943749475, 0.7901550123756904,
        0.7705132427757893, 0.7501110696304596, 0.7289686274214116,
        0.7071067811865476, 0.6845471059286886, 0.6613118653236518,
        0.6374239897486896, 0.6129070536529766, 0.5877852522924731,
        0.5620833778521306, 0.5358267949789965, 0.5090414157503712,
        0.48175367410171516, 0.4539904997395468, 0.42577929156507266,
        0.39714789063478056, 0.3681245526846781, 0.3387379202452915,
        0.30901699437494745, 0.2789911060392295, 0.24868988716485496,
        0.2181432413965427, 0.18738131458572474, 0.15643446504023092,
        0.12533323356430426, 0.0941083133185145, 0.06279051952931353,
        0.031410759078128396, 6.123233995736766e-17, -0.03141075907812828,
        -0.0627905195293134, -0.09410831331851438, -0.12533323356430437,
        -0.15643446504023104, -0.18738131458572482, -0.21814324139654234,
        -0.24868988716485463, -0.27899110603922916, -0.30901699437494734,
        -0.33873792024529137, -0.368124552684678, -0.39714789063478045,
        -0.4257792915650727, -0.4539904997395467, -0.48175367410171543,
        -0.5090414157503713, -0.5358267949789969, -0.5620833778521304,
        -0.587785252292473, -0.6129070536529763, -0.6374239897486897,
        -0.6613118653236517, -0.6845471059286887, -0.7071067811865475,
        -0.7289686274214113, -0.7501110696304596, -0.7705132427757891,
        -0.7901550123756904, -0.8090169943749473, -0.8270805742745619,
        -0.8443279255020149, -0.8607420270039435, -0.8763066800438634,
        -0.8910065241883678, -0.9048270524660194, -0.9177546256839811,
        -0.9297764858882513, -0.9408807689542255, -0.9510565162951535,
        -0.9602936856769431, -0.9685831611286311, -0.9759167619387474,
        -0.9822872507286886, -0.9876883405951377, -0.9921147013144778,
        -0.99556196460308, -0.9980267284282716, -0.9995065603657316, -1.0
    };
    /* Absolute tolerance of 1e-6 on each result. */
    constexpr float tolerance = 1e-6f;

    /* The loop indexes both tables by the same index; make sure neither is
     * shorter than the other before reading. */
    REQUIRE(inputA.size() == expectedResult.size());
    for (size_t i = 0; i < inputA.size(); i++) {
        CHECK(expectedResult[i] ==
              Approx(arm::app::math::MathUtils::CosineF32(M_PI * inputA[i])).margin(tolerance));
    }
}
/*
 * Checks MathUtils::SineF32 against double-precision reference values of
 * sin(pi * x) for x = 0.00, 0.01, ..., 1.00, i.e. angles spanning [0, pi].
 */
TEST_CASE("Test SineF32")
{
    /* Test constants: fractions of pi at which the sine is evaluated. */
    const std::vector<double> inputA{
        0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1,
        0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21,
        0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32,
        0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43,
        0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54,
        0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65,
        0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76,
        0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87,
        0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98,
        0.99, 1.0
    };
    /* Reference values: expectedResult[i] == sin(pi * inputA[i]). */
    const std::vector<double> expectedResult{
        0.0, 0.03141075907812829, 0.06279051952931337, 0.09410831331851431,
        0.12533323356430426, 0.15643446504023087, 0.1873813145857246,
        0.21814324139654256, 0.2486898871648548, 0.2789911060392293,
        0.3090169943749474, 0.33873792024529137, 0.3681245526846779,
        0.3971478906347806, 0.4257792915650727, 0.45399049973954675,
        0.4817536741017153, 0.5090414157503713, 0.5358267949789967,
        0.5620833778521306, 0.5877852522924731, 0.6129070536529764,
        0.6374239897486896, 0.6613118653236518, 0.6845471059286886,
        0.7071067811865475, 0.7289686274214116, 0.7501110696304596,
        0.7705132427757893, 0.7901550123756903, 0.8090169943749475,
        0.8270805742745618, 0.8443279255020151, 0.8607420270039436,
        0.8763066800438637, 0.8910065241883678, 0.9048270524660196,
        0.9177546256839811, 0.9297764858882513, 0.9408807689542255,
        0.9510565162951535, 0.960293685676943, 0.9685831611286311,
        0.9759167619387473, 0.9822872507286886, 0.9876883405951378,
        0.9921147013144779, 0.99556196460308, 0.9980267284282716,
        0.9995065603657316, 1.0, 0.9995065603657316,
        0.9980267284282716, 0.99556196460308, 0.9921147013144778,
        0.9876883405951377, 0.9822872507286886, 0.9759167619387474,
        0.9685831611286312, 0.9602936856769431, 0.9510565162951536,
        0.9408807689542255, 0.9297764858882513, 0.9177546256839813,
        0.9048270524660195, 0.8910065241883679, 0.8763066800438635,
        0.8607420270039436, 0.844327925502015, 0.827080574274562,
        0.8090169943749475, 0.7901550123756905, 0.7705132427757893,
        0.7501110696304597, 0.7289686274214114, 0.7071067811865476,
        0.6845471059286888, 0.6613118653236518, 0.6374239897486899,
        0.6129070536529764, 0.5877852522924732, 0.5620833778521305,
        0.535826794978997, 0.5090414157503714, 0.4817536741017156,
        0.45399049973954686, 0.4257792915650729, 0.3971478906347806,
        0.36812455268467814, 0.3387379202452913, 0.3090169943749475,
        0.2789911060392291, 0.24868988716485482, 0.21814324139654231,
        0.18738131458572502, 0.15643446504023098, 0.12533323356430454,
        0.09410831331851435, 0.06279051952931358, 0.031410759078128236,
        1.2246467991473532e-16
    };
    /* Absolute tolerance of 1e-3 on each result.
     * NOTE(review): this is three orders of magnitude looser than the cosine
     * test; presumably it reflects SineF32's accuracy on some backend — confirm
     * before tightening. */
    constexpr float tolerance = 1e-3f;

    /* The loop indexes both tables by the same index; make sure neither is
     * shorter than the other before reading. */
    REQUIRE(inputA.size() == expectedResult.size());
    for (size_t i = 0; i < inputA.size(); i++) {
        CHECK(expectedResult[i] ==
              Approx(arm::app::math::MathUtils::SineF32(M_PI * inputA[i])).margin(tolerance));
    }
}
/*
 * Checks MathUtils::SqrtF32 on zero, one, a non-perfect square (2),
 * a perfect square (9) and pi.
 */
TEST_CASE("Test SqrtF32")
{
    /* Inputs and their square roots, paired by index. */
    std::vector<float> inputA{0, 1, 2, 9, M_PI};
    std::vector<float> expectedResult{0, 1, 1.414213562, 3, 1.772453851};

    for (size_t idx = 0; idx != inputA.size(); ++idx) {
        const auto root = arm::app::math::MathUtils::SqrtF32(inputA[idx]);
        CHECK(expectedResult[idx] == Approx(root));
    }
}
/*
 * Checks MathUtils::MeanF32 on a one-hot vector, a vector whose values cancel
 * out, and two constant vectors.
 */
TEST_CASE("Test MeanF32")
{
    /* Nine zeros followed by a single one: the mean is 1/10. */
    std::vector<float> oneHot(10, 0.0f);
    oneHot.back() = 1.0f;
    const float oneHotMean = 0.100;
    CHECK(oneHotMean ==
          Approx(arm::app::math::MathUtils::MeanF32(oneHot.data(), oneHot.size())));

    /* Values symmetric about zero: the mean is 0. */
    std::vector<float> symmetric{1, 2, -1, -2};
    const float symmetricMean = 0.0f;
    CHECK(symmetricMean ==
          Approx(arm::app::math::MathUtils::MeanF32(symmetric.data(), symmetric.size())));

    /* Constant vectors: the mean equals the constant. */
    std::vector<float> zeros(9, 0.0f);
    const float zerosMean = 0.0f;
    CHECK(zerosMean ==
          Approx(arm::app::math::MathUtils::MeanF32(zeros.data(), zeros.size())));

    std::vector<float> ones(9, 1.0f);
    const float onesMean = 1.0f;
    CHECK(onesMean ==
          Approx(arm::app::math::MathUtils::MeanF32(ones.data(), ones.size())));
}
/*
 * Checks MathUtils::StdDevF32 on random normal samples, on constant vectors
 * (zero spread) and on the digits 0..9.
 */
TEST_CASE("Test StdDevF32")
{
    /* Arithmetic mean computed with the standard library, so this test does
     * not depend on MathUtils::MeanF32. */
    auto stdMean = [](const std::vector<float>& values) {
        return std::accumulate(values.begin(), values.end(), 0.0f) /
               static_cast<float>(values.size());
    };

    /* Normally distributed sample data generated by numpy's normal generator. */
    std::vector<float> samples{
        1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387,
        1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071,
        -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842,
        0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434,
        0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547,
        -0.69166075, -0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646,
        -1.11731035, 0.2344157, 1.65980218, 0.74204416, -0.19183555, -0.88762896,
        -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514,
        0.12015895, 0.61720311
    };
    const uint32_t sampleCount = samples.size();
    const float sampleMean = stdMean(samples);
    const float sampleStdDev =
        arm::app::math::MathUtils::StdDevF32(samples.data(), sampleCount, sampleMean);
    /* Reference standard deviation of the samples above, computed offline. */
    const float expectedSampleStdDev = 0.969589282958136;
    CHECK(expectedSampleStdDev == Approx(sampleStdDev));

    /* Constant vectors have zero spread, whatever the constant. */
    std::vector<float> zeros(4, 0.0f);
    const float zerosStdDev = 0.0f;
    CHECK(zerosStdDev ==
          Approx(arm::app::math::MathUtils::StdDevF32(zeros.data(), zeros.size(), 0.0f)));

    std::vector<float> ones(4, 1.0f);
    const float onesStdDev = 0.0f;
    CHECK(onesStdDev ==
          Approx(arm::app::math::MathUtils::StdDevF32(ones.data(), ones.size(), 1.0f)));

    /* Digits 0..9: mean 4.5, population variance 8.25, std dev sqrt(8.25). */
    std::vector<float> digits{1, 2, 3, 4, 5, 6, 7, 8, 9, 0};
    const float digitsMean = stdMean(digits);
    const float digitsStdDev = 2.872281323;
    CHECK(digitsStdDev ==
          Approx(arm::app::math::MathUtils::StdDevF32(digits.data(), digits.size(), digitsMean)));
}
/*
 * Checks MathUtils::FftF32 in two ways.
 *
 * Forward: all-zeros, all-ones and random inputs are transformed with the
 * instance's default FFT type. The reference data is stored in the packed layout
 * [real0, realN/2, real1, im1, real2, im2, ...], and the first half of each
 * output is compared with it.
 *
 * Inverse: each reference spectrum is expanded to its full conjugate-symmetric
 * complex form and inverted via IFFT(X) = (1/N) * conj(FFT(conj(X))) using the
 * complex FFT. The result must reproduce the original real input.
 */
TEST_CASE("Test FFT32")
{
    constexpr size_t nElem = 512;
    /* Test constants: */
    const std::vector<float> input_zeros(nElem, 0);
    const std::vector<float> input_ones(nElem, 1);
    /* Random numbers generated using numpy rand with range [0:1] */
    const std::vector<float> input_random{
        0.42333686, 0.6547418, 0.8933691, 0.91466254, 0.5992143, 0.99474055, 0.97750413,
        0.97160685, 0.72718734, 0.8699537, 0.643911, 0.09764466, 0.0050113136, 0.46823388,
        0.13709934, 0.44892532, 0.59728205, 0.04055081, 0.5579888, 0.18445836, 0.66469765,
        0.82715863, 0.91934484, 0.7844356, 0.23489648, 0.021708783, 0.67819905, 0.75761676,
        0.48374954, 0.14006922, 0.87082034, 0.7694296, 0.80479276, 0.8241704, 0.95917296,
        0.5758142, 0.16839339, 0.34290153, 0.5846108, 0.6878044, 0.1067114, 0.5198196,
        0.4356897, 0.68049103, 0.12480807, 0.3538696, 0.06067087, 0.056964435, 0.5382167,
        0.07761527, 0.6673144, 0.9045368, 0.11050189, 0.03530183, 0.07864744, 0.98752064,
        0.014321936, 0.101833574, 0.43293256, 0.87102246, 0.52411795, 0.90232223, 0.49560344,
        0.6803092, 0.2908511, 0.14653015, 0.99105513, 0.7057098, 0.09623502, 0.039713606,
        0.88669086, 0.56018597, 0.90632766, 0.99241334, 0.18748309, 0.38991618, 0.6359827,
        0.05665585, 0.732304, 0.2703365, 0.19014524, 0.5017947, 0.78862536, 0.81253093,
        0.35050204, 0.2832596, 0.65221876, 0.59856164, 0.42758793, 0.78865635, 0.30943435,
        0.93780816, 0.62568265, 0.35397422, 0.84209913, 0.48590583, 0.34837773, 0.5811646,
        0.42924216, 0.26692122, 0.030709852, 0.84459823, 0.09085059, 0.29297647, 0.48539516,
        0.33488297, 0.7877257, 0.8728821, 0.28454545, 0.7109578, 0.86097074, 0.8536262,
        0.4978063, 0.5760398, 0.77506036, 0.7716988, 0.27041402, 0.52340513, 0.2055419,
        0.8728235, 0.13492358, 0.79122984, 0.52998376, 0.33897072, 0.6426309, 0.8766521,
        0.89287037, 0.74047667, 0.42341164, 0.67437655, 0.4682156, 0.67123246, 0.54287183,
        0.3580476, 0.94756556, 0.2699457, 0.6131569, 0.75043845, 0.8115012, 0.49610943,
        0.7108478, 0.90941435, 0.02233071, 0.37346774, 0.33732748, 0.46691266, 0.35784695,
        0.39391598, 0.8556212, 0.884142, 0.11730601, 0.550112, 0.31513855, 0.69654715,
        0.58585805, 0.4493127, 0.78515726, 0.8176612, 0.9846698, 0.32842383, 0.41843212,
        0.48470423, 0.6757128, 0.95876855, 0.5989163, 0.13587572, 0.72886884, 0.88291156,
        0.34402263, 0.66211045, 0.86188424, 0.21498202, 0.26397392, 0.67372984, 0.91386956,
        0.7339788, 0.91308993, 0.1953016, 0.1539217, 0.214701, 0.58234113, 0.8019992,
        0.63969976, 0.041050985, 0.7293308, 0.26341477, 0.54768014, 0.97596467, 0.12385198,
        0.44149798, 0.5519762, 0.1697347, 0.577215, 0.8213594, 0.47874716, 0.64515114,
        0.61467725, 0.18463866, 0.23890929, 0.51052976, 0.16807361, 0.53142565, 0.2414274,
        0.41690814, 0.98815554, 0.6245643, 0.9477003, 0.24780034, 0.82469565, 0.8614785,
        0.9565832, 0.062440686, 0.9710724, 0.039196696, 0.11030199, 0.35234734, 0.02065066,
        0.12832293, 0.7328055, 0.48924434, 0.17247158, 0.5769348, 0.44146806, 0.53575355,
        0.17258933, 0.6980237, 0.86494404, 0.50573164, 0.5033998, 0.71199447, 0.41353586,
        0.26767612, 0.3789118, 0.046621118, 0.58491063, 0.22861995, 0.03134273, 0.53280216,
        0.23382367, 0.07748905, 0.96875405, 0.6613716, 0.64087844, 0.8377165, 0.051519375,
        0.68997836, 0.3776376, 0.43362603, 0.5358754, 0.51419014, 0.12823892, 0.26574057,
        0.508808, 0.15734084, 0.78327274, 0.5045347, 0.5445746, 0.89297736, 0.8531272,
        0.91270804, 0.87429863, 0.3965137, 0.13544834, 0.74269205, 0.80592203, 0.045050766,
        0.13362087, 0.17090783, 0.02873757, 0.99339336, 0.6394376, 0.48203012, 0.70598215,
        0.37082237, 0.39792424, 0.89938444, 0.312602, 0.48755112, 0.18220617, 0.17303479,
        0.31954846, 0.78080165, 0.1755106, 0.68262285, 0.84665287, 0.8520143, 0.8459509,
        0.39417005, 0.30087698, 0.81362164, 0.61927587, 0.32739028, 0.9023775, 0.27578092,
        0.6830477, 0.15842387, 0.8473049, 0.43057114, 0.2019703, 0.20560141, 0.6237757,
        0.60283095, 0.27645138, 0.26605442, 0.27985683, 0.41353813, 0.85139906, 0.71711886,
        0.5444832, 0.73613757, 0.7397004, 0.7406752, 0.41016674, 0.31896713, 0.4541723,
        0.2795807, 0.47941738, 0.00504193, 0.89091027, 0.8097144, 0.63033766, 0.37252298,
        0.9132861, 0.5102532, 0.04104481, 0.30368647, 0.21573475, 0.99520445, 0.5047808,
        0.6868845, 0.99881023, 0.30377692, 0.2554386, 0.47201005, 0.11120686, 0.10077732,
        0.1853349, 0.49159425, 0.3938629, 0.8989509, 0.9887155, 0.698771, 0.695701,
        0.78368753, 0.52537227, 0.19451462, 0.3659248, 0.1968508, 0.7751828, 0.33103722,
        0.40406147, 0.37832898, 0.68663514, 0.32225925, 0.41771907, 0.034218453, 0.42808908,
        0.20685343, 0.1861495, 0.045986768, 0.8532299, 0.17200677, 0.44670314, 0.56831235,
        0.5388232, 0.5430553, 0.69175136, 0.6462231, 0.42827028, 0.10050113, 0.30627027,
        0.9967943, 0.6684778, 0.5928422, 0.63392985, 0.99123496, 0.79301435, 0.7936309,
        0.42839453, 0.39781123, 0.22329247, 0.0122212395, 0.2807108, 0.19812097, 0.5576105,
        0.115653396, 0.3732018, 0.7622857, 0.19847734, 0.5310287, 0.7298145, 0.5518292,
        0.9117333, 0.13215758, 0.33716795, 0.42372775, 0.6779287, 0.35799992, 0.097887225,
        0.20171605, 0.9948177, 0.1829232, 0.80349857, 0.9807098, 0.22959666, 0.67322475,
        0.63094735, 0.93454355, 0.15962408, 0.04335433, 0.47104993, 0.36784375, 0.45258796,
        0.93415564, 0.1655446, 0.7195017, 0.76236975, 0.3846913, 0.01330617, 0.84716374,
        0.1227003, 0.65102947, 0.6632434, 0.3728453, 0.4222391, 0.6942989, 0.16014872,
        0.10798196, 0.94033676, 0.026525471, 0.8379024, 0.5484514, 0.13500613, 0.22919805,
        0.7001831, 0.6573261, 0.38086265, 0.8725666, 0.35077834, 0.28415123, 0.42283052,
        0.668379, 0.9769895, 0.37621376, 0.646407, 0.11188069, 0.17129017, 0.7441628,
        0.25617477, 0.7751679, 0.8565412, 0.67631435, 0.45213568, 0.61896557, 0.3387995,
        0.51607716, 0.60779697, 0.16428445, 0.5080923, 0.13012086, 0.61184275, 0.7690249,
        0.9578811, 0.67365676, 0.16241212, 0.97157824, 0.5595742, 0.75936574, 0.6043881,
        0.2149638, 0.4925318, 0.58727825, 0.97953695, 0.01605968, 0.2819307, 0.6448378,
        0.4265335, 0.661541, 0.3976571, 0.40607136, 0.46425515, 0.2055872, 0.2716193,
        0.4132582, 0.8372537, 0.37787434, 0.082228854, 0.7985557, 0.9718134, 0.35222608,
        0.4853643, 0.2569464, 0.14783978, 0.4889042, 0.62900156, 0.19994198, 0.4618481,
        0.21673755, 0.51749533, 0.1260157, 0.83759904, 0.36438805, 0.6704668, 0.22010763,
        0.2359318, 0.53004104, 0.9723652, 0.91218954, 0.9153926, 0.48207277, 0.34850466,
        0.8939421};
    const std::vector<std::vector<float>> input_vectors{input_zeros,
                                                        input_ones,
                                                        input_random};
    /* One zero-initialised output buffer per input vector. */
    std::vector<std::vector<float>> output_vectors(input_vectors.size(),
                                                   std::vector<float>(nElem, 0));
    std::vector<float> expected_result_zeros(nElem, 0);
    std::vector<float> expected_result_ones(nElem, 0);
    /* The DC bin of an all-ones signal is the sum of its samples. */
    expected_result_ones[0] = static_cast<float>(nElem);
    /* Values are stored as [real0, realN/2, real1, im1, real2, im2, ...] */
    std::vector<float> expected_result_random{
        259.510161, 2.59796867, 2.55982143, -5.91349888, -1.80049237, 1.09902763,
        4.0094324, 7.76684892, 4.32617219, -4.33636417, 2.98128463, -0.83763449,
        2.92973078, -3.75655459, 2.27203161, -4.61106145, 5.55562176, -4.71880166,
        2.13693416, -2.20496619, 2.31174036, -3.52991041, -0.61687068, 2.43455407,
        4.76317833, 8.66518565, 1.72350562, -1.33641312, 5.82836675, 0.89396187,
        8.15031483, 4.34599034, 5.99780199, 2.94900065, -0.0234462045, -5.03789597,
        12.190702, -6.47012928, 1.24434715, 0.23621713, 0.920921279, -9.20510398,
        -0.773267254, -3.72141078, 6.28883709, 4.00634065, 4.46491682, 0.74625307,
        -0.587506158, -4.22833058, 3.15189786, -1.82518672, -6.9378226, 2.4170692,
        3.23045185, 3.33383799, 0.0510059531, -3.4233929, -2.91651323, -0.0258584,
        5.84499843, -9.51454903, -14.9214047, 5.52200123, 4.5217959, -7.08268703,
        0.51677542, -2.90878759, 5.04314682, -1.16928599, -10.7329243, -0.2719951,
        -3.95269565, -2.32475678, -4.11031641, -2.20538835, -0.589005095, -5.65483456,
        -10.8927018, 5.74801823, -5.72520347, -3.94970165, -0.518407515, 1.23622633,
        9.56297959, 0.24424306, -3.74306351, 9.63476301, -4.74493837, -0.35443496,
        6.54760504, 1.16188913, -13.341695, 7.19088609, -2.560458, -0.49557866,
        2.93460322, 6.91076746, -0.284779221, 6.59958391, 1.70963995, -7.67293252,
        4.1850079, -0.14627552, -1.24855113, -1.43322867, 0.360644904, -4.11521374,
        -10.0628421, 2.87531563, -9.34809732, -0.58251846, 3.61799848, -5.10288284,
        -5.96239076, 1.99792128, -0.0783229243, 1.81741166, 2.32709681, 0.68487206,
        -3.08398468, 1.5177629, 1.41015388, 4.51146401, 1.90769911, 0.56093423,
        3.58389141, -0.14974575, -2.20163907, -1.62177814, -1.91904127, 1.94645907,
        -0.13772293, -4.30291678, -2.61843435, 0.12691241, 3.28959117, 4.7309582,
        -2.93995652, 0.39835926, 2.89711768, 1.42284586, -5.13129145, -7.26477374,
        3.74616158, -2.59659457, 3.8574875, -2.93737277, 3.17748694, -4.45041455,
        -2.68466437, 1.37377726, 1.60008368, 1.63787578, -1.95661401, -6.34937202,
        -2.62744282, -5.20892662, 0.890553959, -6.37113573, 2.35885332, -10.04547561,
        0.329866159, 1.89217741, 0.882516491, -3.53298728, -2.22525608, -5.64794388,
        -5.19226843, 2.5971315, 4.49346648, -0.20428409, -6.14851885, -11.90893875,
        3.75899776, -1.86910056, 2.78518535, -2.6359501, -2.13423317, 4.86509946,
        -2.37625499, 7.42404308, 6.71175474, 6.06191618, 2.59014379, 4.76329698,
        9.19140042, 9.69149015, 3.33307819, -4.03094924, 2.12988453, -0.15820258,
        -10.3422801, -3.04462388, 3.59852152, -2.00887343, -3.69998656, 0.90050102,
        0.679959099, -1.88604949, 1.24235316, 0.41309537, 6.13876866, -7.1040085,
        6.17728674, 1.91667103, -1.32895472, -0.17674661, -6.94720428, 3.10502593,
        -2.33990738, -1.27840434, 3.2144252, 2.14102714, 2.37498837, 3.8158066,
        -2.24107675, -5.52527559, -2.9569793, -0.50367608, 3.01687661, 7.08195792,
        6.7860479, -3.94154162, 2.24402195, 4.60132638, 3.42211139, 4.17689039,
        -1.17277194, 2.15404472, -2.3748193, 1.42611867, -0.463033506, 3.21563035,
        1.38662123, 3.98598717, -3.75283402, -2.47600433, -1.97290542, 2.83361487,
        -0.845662834, 5.57411581, -0.972981483, -11.394208, 1.88220611, -0.80225125,
        -0.434295854, -8.2954126, 3.81795409, -3.17146, -4.61994107, -1.59820505,
        -5.98834455, -4.93129451, -0.513862996, -0.15649305, -5.59094391, 6.25244435,
        -6.59974456, 13.17193115, 4.48609092, 1.64741879, 7.40985006, 0.44896188,
        3.81058449, -0.76425931, -5.47938416, 4.01447941, -3.21535548, -1.45542238,
        0.72274083, -0.23983128, -4.32373034, 0.1337671, -5.89365226, -3.18756318,
        7.90979161, 5.27570134, -3.43094553, -6.00826981, 1.17932561, -3.50027177,
        0.181306385, 1.1062498, 0.723650536, -1.55500613, -3.88047911, -2.43746762,
        -6.81565579, 2.16343352, 2.46366137, 2.38704469, -2.55106395, 6.5091449,
        -2.06510578, 11.11320924, 2.06649835, -1.05026064, 1.63564303, -0.04638729,
        1.45053876, 0.43730146, 1.25027939, 0.79932743, 2.81088838, 6.95136058,
        -4.41417255, 2.89610628, 1.15426258, -2.60704937, -2.77744882, 4.12872365,
        -2.98288336, 6.75607352, 2.36553382, -2.10540332, -7.30042988, -5.44897893,
        3.44048454, 4.29726231, 2.181995, -0.80126759, -4.04051175, 4.57584864,
        0.956312116, -4.45183318, 3.42348929, -9.84138181, 8.69604433, -6.6481311,
        0.468232735, 1.41031176, -1.240857, 10.61672181, 0.356591473, 10.51631421,
        -0.99743547, 2.72157537, 8.63583929, -2.19404252, -1.53605811, -4.41068581,
        2.05371873, 1.25665769, 1.65289503, 4.52520582, -0.535062642, 0.82084677,
        -11.0079476, -5.09361474, 9.63129107, 3.90056638, -4.19779738, 0.06565745,
        2.42526917, 2.5854233, -3.66709357, 3.80502971, -0.101489353, -6.85423228,
        -13.9361494, -0.43904617, 6.01800968, -1.30751495, -4.75122234, 2.74740671,
        5.54971138, 9.43409003, -0.994733058, -3.0096825, -8.60263376, 0.36653762,
        3.53318614, -2.69194556, -8.9514574, -4.71570923, 5.15417709, 2.68645385,
        -2.78042293, 8.21739385, 0.590225003, 2.13319153, 1.72158888, 0.18114627,
        3.92269446, 3.3525857, -3.40313825, 4.39280934, -1.70368966, 1.29121245,
        -3.11326453, -1.85941318, 4.57078881, 0.72531039, 6.00445664, 4.9588524,
        -3.32944491, -0.02080722, -7.42374632, -3.23290026, -0.81614579, 3.55935439,
        -0.619206533, 2.42859073, 2.21486456, 3.76402487, 3.90930695, -3.61610186,
        -0.812712547, -14.63377988, 1.14460823, -3.14089899, 3.18097435, 1.21957751,
        2.85181833, 0.89990235, -4.32147361, -5.54219361, -1.12253677, 2.96141081,
        -4.4257707, 3.17282306, 4.9174671, 1.16977744, -4.55148089, -2.82520179,
        -1.71684103, 1.91487668, 0.770726836, 0.78534837, -5.91048566, -4.8288477,
        -1.35560162, -3.60938315, 1.15812301, 2.44299541, 1.3611519, -5.40950935,
        7.08292127, 0.27720591, -0.160210828, 2.75862348, -1.57403782, 9.97207524,
        -2.08957576, 8.70299964, -5.33004663, 4.1547783, 3.51580675, -5.10788085,
        4.37938353, -3.73449894, 1.44673271, 0.51941469, 0.852232446, -1.1134965,
        -1.43972745, -1.62952127, -2.50759973, 1.19012213, 0.572772282, -2.71833059,
        -6.8471899, 4.2621535, 1.58954734, -0.53827818, -0.144624396, 7.63866979,
        0.410423977, -2.4785678, -5.02681867, -2.03469811, 0.959505727, 2.68589705,
        3.20889444, -10.76452533, -3.84771551, 2.49189796, -3.19895938, -3.49948794,
        -2.6723897, 5.11386526, -3.85957031, -1.40741978, 0.176663166, -11.7111276,
        0.639997364, -1.30321198, 3.20767633, 1.65750671, -11.6187257, -4.36634782,
        -3.18675281, -4.89279155, -4.08760307, 2.19269283, -1.5892487, 0.17948212,
        4.81376107, 2.01871001, -0.324211095, -0.2790092, 1.12603878, -3.61503491,
        -2.86982317, -3.03634532, 8.0771391, 2.21302089, 2.91496011, -2.58564072,
        0.0, 0.0};
    const std::vector<std::vector<float>> expected_results_vectors{
        expected_result_zeros, expected_result_ones, expected_result_random};

    /* Absolute tolerances for the forward and the round-trip (inverse) checks. */
    constexpr float fwdTolerance = 1e-3f;
    constexpr float invTolerance = 0.1f;

    arm::app::math::FftInstance fftInstance;

    /* Forward FFT of each input, compared with the corresponding reference. */
    for (size_t j = 0; j < input_vectors.size(); j++) {
        const auto fftLen = static_cast<uint16_t>(input_vectors[j].size());
        /* The inverse-FFT pass below reads up to index fftLen - 1 of the reference. */
        REQUIRE(expected_results_vectors[j].size() >= fftLen);

        /* Transform a copy. FFT kernels may use their source buffer as scratch
         * (CMSIS-DSP's arm_rfft_fast_f32, for one, modifies it), and input_vectors
         * must stay pristine because it is the expected output of the inverse check. */
        std::vector<float> fftInput = input_vectors[j];
        arm::app::math::MathUtils::FftInitF32(fftLen, fftInstance);
        arm::app::math::MathUtils::FftF32(fftInput, output_vectors[j], fftInstance);

        /* NOTE(review): only the first half of the packed output is compared. The
         * last pair in expected_result_random is 0.0, 0.0, which may mean the second
         * half of that table is not an exact reference; confirm before widening. */
        for (size_t i = 0; i < fftLen / 2; i++) {
            CHECK(output_vectors[j][i] ==
                  Approx(expected_results_vectors[j][i]).margin(fwdTolerance));
        }
    }

    /* Test inverse FFTs using the forward FFT for complex numbers.
     * IFFT(XVec) = (1/N)(Conj(FFT(Conj(Xvec)))) */
    for (size_t j = 0; j < input_vectors.size(); j++) {
        const auto fftLen = static_cast<uint16_t>(input_vectors[j].size());
        const size_t inputSz = fftLen * 2;

        /* Interleaved (re, im) complex input of fftLen bins holding conj(X[k]). */
        std::vector<float> inputWithConjugates(inputSz);
        /* We expect the output of this test to return the original input. */
        const std::vector<float>& expectedOutputVector = input_vectors[j];
        /* Placeholder for output vector. */
        std::vector<float> outputVector(inputWithConjugates.size());

        /* Bins 0 and N/2 are purely real: their values come from packed slots 0
         * and 1, and their imaginary parts stay zero. Bin N/2 sits at float
         * index 2 * (N/2) == fftLen. */
        inputWithConjugates[0] = expected_results_vectors[j][0];
        inputWithConjugates[fftLen] = expected_results_vectors[j][1];

        /* For bins k = 1 .. N/2-1 (float index i = 2k), store conj(X[k]).
         * For bins N/2+1 .. N-1, the real input gives conj(X[N/2 + k']) =
         * X[N/2 - k'], which is read directly from packed index fftLen - i. */
        for (size_t i = 2; i < fftLen; i += 2) {
            inputWithConjugates[i] = expected_results_vectors[j][i];
            inputWithConjugates[i + 1] = 0 - expected_results_vectors[j][i + 1];
            inputWithConjugates[fftLen + i] = expected_results_vectors[j][fftLen - i];
            inputWithConjugates[fftLen + i + 1] = expected_results_vectors[j][fftLen - i + 1];
        }

        arm::app::math::MathUtils::FftInitF32(
            fftLen, fftInstance, arm::app::math::FftType::complex);
        arm::app::math::MathUtils::FftF32(inputWithConjugates, outputVector, fftInstance);

        /* The forward transform is unnormalised: divide by this vector's length N. */
        const auto scale = static_cast<float>(fftLen);
        for (size_t i = 0; i < expectedOutputVector.size(); i++) {
            CHECK(outputVector[i * 2] / scale ==
                  Approx(expectedOutputVector[i]).margin(invTolerance));
            /* The imaginary part should be close to 0 as the original input was real. */
            CHECK(outputVector[i * 2 + 1] / scale == Approx(0.f).margin(invTolerance));
        }
    }
}
/*
 * Checks MathUtils::VecLogarithmF32 element-wise against natural-log values
 * of a tiny positive number, 0.5, 1, pi and e.
 */
TEST_CASE("Test VecLogarithmF32")
{
    std::vector<float> input{0.1e-10, 0.5, 1, M_PI, M_E};
    /* ln() of each input above, in the same order. */
    std::vector<float> expectedResult{-25.328436, -0.693147181, 0, 1.144729886, 1};
    std::vector<float> output(input.size());

    arm::app::math::MathUtils::VecLogarithmF32(input, output);

    for (size_t k = 0; k < expectedResult.size(); ++k) {
        const float actual = output[k];
        CHECK(expectedResult[k] == Approx(actual));
    }
}
/*
 * Checks MathUtils::DotProductF32 on orthogonal vectors (result 0) and on a
 * pair whose dot product is non-zero. The non-zero case catches an
 * implementation that ignores its inputs.
 */
TEST_CASE("Test DotProductF32")
{
    /* Orthogonal vectors: the dot product is exactly 0. */
    std::vector<float> inputA{1, 1, 1, 0, 0, 0};
    std::vector<float> inputB{0, 0, 0, 1, 1, 1};
    const uint32_t len = inputA.size();
    const float expectedResult = 0;
    const float dot_prod =
        arm::app::math::MathUtils::DotProductF32(inputA.data(), inputB.data(), len);
    CHECK(dot_prod == expectedResult);

    /* 1*5 + 2*6 + 3*7 + 4*8 = 70. Every partial sum is a small integer, so the
     * result is exact in float whatever the summation order. */
    std::vector<float> inputC{1, 2, 3, 4};
    std::vector<float> inputD{5, 6, 7, 8};
    const float dotProdCD = arm::app::math::MathUtils::DotProductF32(
        inputC.data(), inputD.data(), static_cast<uint32_t>(inputC.size()));
    CHECK(dotProdCD == Approx(70.0f));
}
/*
 * Checks MathUtils::ComplexMagnitudeSquaredF32 on three complex numbers
 * stored as interleaved (real, imaginary) pairs.
 */
TEST_CASE("Test ComplexMagnitudeSquaredF32")
{
    /* 0+0i, 0.5+0.5i and 1+1i, interleaved as (re, im). */
    std::vector<float> input{0.0, 0.0, 0.5, 0.5, 1, 1};
    /* re^2 + im^2 for each pair above. */
    const std::vector<float> expectedResult{0.0, 0.5, 2};

    const size_t inputLen = input.size();
    const size_t outputLen = inputLen / 2;
    std::vector<float> output(outputLen);

    /* Raw pointers are passed because the function may overwrite the first half
     * of the input buffer with its results. */
    arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(
        input.data(), inputLen, output.data(), outputLen);

    for (size_t k = 0; k < outputLen; ++k) {
        CHECK(expectedResult[k] == Approx(output[k]));
    }
}
/**
 * @brief Runs MathUtils::SoftmaxF32 on a copy of @p input and compares the
 *        result element-wise with @p goldenOutput.
 *
 * @param input        Input vector. It is not modified; the softmax runs on a copy.
 * @param goldenOutput Expected output vector. It must have the same length as the output.
 */
static void TestSoftmaxF32(
    const std::vector<float>& input,
    const std::vector<float>& goldenOutput)
{
    std::vector<float> output = input; /* Function modifies the vector in-place */
    arm::app::math::MathUtils::SoftmaxF32(output);

    /* Check the sizes before indexing. The loop reads output[i] for every golden
     * index, which would run past the end of a shorter output. */
    REQUIRE(output.size() == goldenOutput.size());
    for (size_t i = 0; i < goldenOutput.size(); ++i) {
        CHECK(goldenOutput[i] == Approx(output[i]));
    }
}
/*
 * Softmax scenarios, each run in its own Catch SECTION through the
 * TestSoftmaxF32 helper: an integer ramp, random values in [0, 1), and a
 * two-element input where one value dominates.
 */
TEST_CASE("Test SoftmaxF32")
{
    SECTION("Simple series") {
        /* Consecutive integers 0..9. */
        const std::vector<float> logits{
            0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0
        };
        const std::vector<float> probabilities{
            7.80134161e-05, 2.12062451e-04, 5.76445508e-04, 1.56694135e-03,
            4.25938820e-03, 1.15782175e-02, 3.14728583e-02, 8.55520989e-02,
            2.32554716e-01, 6.32149258e-01
        };
        TestSoftmaxF32(logits, probabilities);
    }

    SECTION("Random series") {
        const std::vector<float> logits{
            0.8810943246170809, 0.5877587675947015, 0.6841546454788743,
            0.4155920960071594, 0.9799415323651671, 0.5066432973545711,
            0.3846024252355448, 0.4568689569632123, 0.3284413744557605,
            0.49152323726213554
        };
        const std::vector<float> probabilities{
            0.13329595, 0.09940837, 0.10946799, 0.08368583, 0.14714509,
            0.09166319, 0.08113220, 0.08721240, 0.07670132, 0.09028766
        };
        TestSoftmaxF32(logits, probabilities);
    }

    SECTION("Series with large STD") {
        /* One value dwarfs the other, so nearly all the probability mass goes to it. */
        const std::vector<float> logits{0.001, 1000.000};
        const std::vector<float> probabilities{0.000, 1.000};
        TestSoftmaxF32(logits, probabilities);
    }
}