liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 1 | /* |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 2 | * SPDX-FileCopyrightText: Copyright 2021 - 2023 Arm Limited and/or its affiliates |
| 3 | * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0 |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 4 | * |
| 5 | * Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | * you may not use this file except in compliance with the License. |
| 7 | * You may obtain a copy of the License at |
| 8 | * |
| 9 | * http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | * |
| 11 | * Unless required by applicable law or agreed to in writing, software |
| 12 | * distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | * See the License for the specific language governing permissions and |
| 15 | * limitations under the License. |
| 16 | */ |
| 17 | #include "PlatformMath.hpp" |
| 18 | #include <catch.hpp> |
| 19 | #include <limits> |
| 20 | #include <numeric> |
| 21 | |
| 22 | TEST_CASE("Test CosineF32") |
| 23 | { |
| 24 | /*Test Constants: */ |
| 25 | std::vector<double> inputA{ |
| 26 | 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, |
| 27 | 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, |
| 28 | 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, |
| 29 | 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43, |
| 30 | 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, |
| 31 | 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, |
| 32 | 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, |
| 33 | 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, |
| 34 | 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, |
| 35 | 0.99, 1.0 |
| 36 | }; |
| 37 | std::vector<double> expectedResult{ |
| 38 | 1.0, 0.9995065603657316, 0.9980267284282716, 0.99556196460308, |
| 39 | 0.9921147013144779, 0.9876883405951378, 0.9822872507286887, |
| 40 | 0.9759167619387474, 0.9685831611286311, 0.9602936856769431, |
| 41 | 0.9510565162951535, 0.9408807689542255, 0.9297764858882515, |
| 42 | 0.9177546256839811, 0.9048270524660195, 0.8910065241883679, |
| 43 | 0.8763066800438636, 0.8607420270039436, 0.8443279255020151, |
| 44 | 0.8270805742745618, 0.8090169943749475, 0.7901550123756904, |
| 45 | 0.7705132427757893, 0.7501110696304596, 0.7289686274214116, |
| 46 | 0.7071067811865476, 0.6845471059286886, 0.6613118653236518, |
| 47 | 0.6374239897486896, 0.6129070536529766, 0.5877852522924731, |
| 48 | 0.5620833778521306, 0.5358267949789965, 0.5090414157503712, |
| 49 | 0.48175367410171516, 0.4539904997395468, 0.42577929156507266, |
| 50 | 0.39714789063478056, 0.3681245526846781, 0.3387379202452915, |
| 51 | 0.30901699437494745, 0.2789911060392295, 0.24868988716485496, |
| 52 | 0.2181432413965427, 0.18738131458572474, 0.15643446504023092, |
| 53 | 0.12533323356430426, 0.0941083133185145, 0.06279051952931353, |
| 54 | 0.031410759078128396, 6.123233995736766e-17, -0.03141075907812828, |
| 55 | -0.0627905195293134, -0.09410831331851438, -0.12533323356430437, |
| 56 | -0.15643446504023104, -0.18738131458572482, -0.21814324139654234, |
| 57 | -0.24868988716485463, -0.27899110603922916, -0.30901699437494734, |
| 58 | -0.33873792024529137, -0.368124552684678, -0.39714789063478045, |
| 59 | -0.4257792915650727, -0.4539904997395467, -0.48175367410171543, |
| 60 | -0.5090414157503713, -0.5358267949789969, -0.5620833778521304, |
| 61 | -0.587785252292473, -0.6129070536529763, -0.6374239897486897, |
| 62 | -0.6613118653236517, -0.6845471059286887, -0.7071067811865475, |
| 63 | -0.7289686274214113, -0.7501110696304596, -0.7705132427757891, |
| 64 | -0.7901550123756904, -0.8090169943749473, -0.8270805742745619, |
| 65 | -0.8443279255020149, -0.8607420270039435, -0.8763066800438634, |
| 66 | -0.8910065241883678, -0.9048270524660194, -0.9177546256839811, |
| 67 | -0.9297764858882513, -0.9408807689542255, -0.9510565162951535, |
| 68 | -0.9602936856769431, -0.9685831611286311, -0.9759167619387474, |
| 69 | -0.9822872507286886, -0.9876883405951377, -0.9921147013144778, |
| 70 | -0.99556196460308, -0.9980267284282716, -0.9995065603657316, -1.0 |
| 71 | }; |
| 72 | |
| 73 | float tolerance = 10e-7; |
| 74 | for (size_t i = 0; i < inputA.size(); i++) { |
| 75 | CHECK (expectedResult[i] == |
| 76 | Approx(arm::app::math::MathUtils::CosineF32(M_PI*inputA[i])).margin(tolerance)); |
| 77 | } |
| 78 | } |
| 79 | |
| 80 | TEST_CASE("Test SineF32") |
| 81 | { |
| 82 | /*Test Constants: */ |
| 83 | std::vector<double> inputA{ |
| 84 | 0.0, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, |
| 85 | 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, |
| 86 | 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, |
| 87 | 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43, |
| 88 | 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, |
| 89 | 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, |
| 90 | 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, |
| 91 | 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, |
| 92 | 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, |
| 93 | 0.99, 1.0 |
| 94 | }; |
| 95 | std::vector<double> expectedResult{ |
| 96 | |
| 97 | 0.0, 0.03141075907812829, 0.06279051952931337, 0.09410831331851431, |
| 98 | 0.12533323356430426, 0.15643446504023087, 0.1873813145857246, |
| 99 | 0.21814324139654256, 0.2486898871648548, 0.2789911060392293, |
| 100 | 0.3090169943749474, 0.33873792024529137, 0.3681245526846779, |
| 101 | 0.3971478906347806, 0.4257792915650727, 0.45399049973954675, |
| 102 | 0.4817536741017153, 0.5090414157503713, 0.5358267949789967, |
| 103 | 0.5620833778521306, 0.5877852522924731, 0.6129070536529764, |
| 104 | 0.6374239897486896, 0.6613118653236518, 0.6845471059286886, |
| 105 | 0.7071067811865475, 0.7289686274214116, 0.7501110696304596, |
| 106 | 0.7705132427757893, 0.7901550123756903, 0.8090169943749475, |
| 107 | 0.8270805742745618, 0.8443279255020151, 0.8607420270039436, |
| 108 | 0.8763066800438637, 0.8910065241883678, 0.9048270524660196, |
| 109 | 0.9177546256839811, 0.9297764858882513, 0.9408807689542255, |
| 110 | 0.9510565162951535, 0.960293685676943, 0.9685831611286311, |
| 111 | 0.9759167619387473, 0.9822872507286886, 0.9876883405951378, |
| 112 | 0.9921147013144779, 0.99556196460308, 0.9980267284282716, |
| 113 | 0.9995065603657316, 1.0, 0.9995065603657316, |
| 114 | 0.9980267284282716, 0.99556196460308, 0.9921147013144778, |
| 115 | 0.9876883405951377, 0.9822872507286886, 0.9759167619387474, |
| 116 | 0.9685831611286312, 0.9602936856769431, 0.9510565162951536, |
| 117 | 0.9408807689542255, 0.9297764858882513, 0.9177546256839813, |
| 118 | 0.9048270524660195, 0.8910065241883679, 0.8763066800438635, |
| 119 | 0.8607420270039436, 0.844327925502015, 0.827080574274562, |
| 120 | 0.8090169943749475, 0.7901550123756905, 0.7705132427757893, |
| 121 | 0.7501110696304597, 0.7289686274214114, 0.7071067811865476, |
| 122 | 0.6845471059286888, 0.6613118653236518, 0.6374239897486899, |
| 123 | 0.6129070536529764, 0.5877852522924732, 0.5620833778521305, |
| 124 | 0.535826794978997, 0.5090414157503714, 0.4817536741017156, |
| 125 | 0.45399049973954686, 0.4257792915650729, 0.3971478906347806, |
| 126 | 0.36812455268467814, 0.3387379202452913, 0.3090169943749475, |
| 127 | 0.2789911060392291, 0.24868988716485482, 0.21814324139654231, |
| 128 | 0.18738131458572502, 0.15643446504023098, 0.12533323356430454, |
| 129 | 0.09410831331851435, 0.06279051952931358, 0.031410759078128236, |
| 130 | 1.2246467991473532e-16 |
| 131 | }; |
| 132 | float tolerance = 10e-4; |
| 133 | for (size_t i = 0; i < inputA.size(); i++) { |
| 134 | CHECK (expectedResult[i] == |
| 135 | Approx(arm::app::math::MathUtils::SineF32(M_PI*inputA[i])).margin(tolerance)); |
| 136 | } |
| 137 | } |
| 138 | |
| 139 | TEST_CASE("Test SqrtF32") |
| 140 | { |
| 141 | /*Test Constants: */ |
| 142 | std::vector<float> inputA{0,1,2,9,M_PI}; |
Kshitij Sisodia | b59ba68 | 2021-11-23 17:19:52 +0000 | [diff] [blame] | 143 | size_t len = inputA.size(); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 144 | std::vector<float> expectedResult{0, 1, 1.414213562, 3, 1.772453851 }; |
| 145 | |
| 146 | for (size_t i=0; i < len; i++){ |
| 147 | CHECK (expectedResult[i] == Approx(arm::app::math::MathUtils::SqrtF32(inputA[i]))); |
| 148 | } |
| 149 | } |
| 150 | |
| 151 | TEST_CASE("Test MeanF32") |
| 152 | { |
Richard Burton | c291144 | 2022-04-22 09:08:21 +0100 | [diff] [blame] | 153 | /* Test Constants: */ |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 154 | std::vector<float> input |
| 155 | {0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 0.000, 1.000}; |
| 156 | |
| 157 | /* Manually calculated mean of above vector */ |
| 158 | float expectedResult = 0.100; |
| 159 | CHECK (expectedResult == Approx(arm::app::math::MathUtils::MeanF32(input.data(), input.size()))); |
Richard Burton | c291144 | 2022-04-22 09:08:21 +0100 | [diff] [blame] | 160 | |
| 161 | /* Mean of 0 */ |
| 162 | std::vector<float> input2{1, 2, -1, -2}; |
| 163 | float expectedResult2 = 0.0f; |
| 164 | CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::MeanF32(input2.data(), input2.size()))); |
| 165 | |
| 166 | /* All 0s */ |
| 167 | std::vector<float> input3 = std::vector<float>(9, 0); |
| 168 | float expectedResult3 = 0.0f; |
| 169 | CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::MeanF32(input3.data(), input3.size()))); |
| 170 | |
| 171 | /* All 1s */ |
| 172 | std::vector<float> input4 = std::vector<float>(9, 1); |
| 173 | float expectedResult4 = 1.0f; |
| 174 | CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::MeanF32(input4.data(), input4.size()))); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 175 | } |
| 176 | |
| 177 | TEST_CASE("Test StdDevF32") |
| 178 | { |
| 179 | /*Test Constants: */ |
| 180 | /* Normally distributed sample data generated by numpy normal library */ |
| 181 | std::vector<float> input |
| 182 | {1.62434536, -0.61175641, -0.52817175, -1.07296862, 0.86540763, -2.3015387, |
| 183 | 1.74481176, -0.7612069, 0.3190391, -0.24937038, 1.46210794, -2.06014071, |
| 184 | -0.3224172, -0.38405435, 1.13376944, -1.09989127, -0.17242821, -0.87785842, |
| 185 | 0.04221375, 0.58281521, -1.10061918, 1.14472371, 0.90159072, 0.50249434, |
| 186 | 0.90085595, -0.68372786, -0.12289023, -0.93576943, -0.26788808, 0.53035547, |
| 187 | -0.69166075,-0.39675353, -0.6871727, -0.84520564, -0.67124613, -0.0126646, |
| 188 | -1.11731035, 0.2344157, 1.65980218, 0.74204416, -0.19183555, -0.88762896, |
| 189 | -0.74715829, 1.6924546, 0.05080775, -0.63699565, 0.19091548, 2.10025514, |
| 190 | 0.12015895, 0.61720311 |
| 191 | }; |
| 192 | uint32_t inputLen = input.size(); |
| 193 | |
| 194 | /*Calculate mean using std library to avoid dependency on MathUtils::MeanF32 */ |
| 195 | float mean = (std::accumulate(input.begin(), input.end(), 0.0f))/float(inputLen); |
| 196 | float output = arm::app::math::MathUtils::StdDevF32(input.data(), inputLen, mean); |
| 197 | |
| 198 | /*Manually calculated standard deviation of above vector*/ |
| 199 | float expectedResult = 0.969589282958136; |
| 200 | |
| 201 | CHECK (expectedResult == Approx(output)); |
Richard Burton | c291144 | 2022-04-22 09:08:21 +0100 | [diff] [blame] | 202 | |
| 203 | /* All 0s should have 0 std dev. */ |
| 204 | std::vector<float> input2 = std::vector<float>(4, 0); |
| 205 | float expectedResult2 = 0.0f; |
| 206 | CHECK (expectedResult2 == Approx(arm::app::math::MathUtils::StdDevF32(input2.data(), input2.size(), 0.0f))); |
| 207 | |
| 208 | /* All 1s should have 0 std dev. */ |
| 209 | std::vector<float> input3 = std::vector<float>(4, 1); |
| 210 | float expectedResult3 = 0.0f; |
| 211 | CHECK (expectedResult3 == Approx(arm::app::math::MathUtils::StdDevF32(input3.data(), input3.size(), 1.0f))); |
| 212 | |
| 213 | /* Manually calclualted std value */ |
| 214 | std::vector<float> input4 {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}; |
| 215 | float mean2 = (std::accumulate(input4.begin(), input4.end(), 0.0f))/float(input4.size()); |
| 216 | float expectedResult4 = 2.872281323; |
| 217 | CHECK (expectedResult4 == Approx(arm::app::math::MathUtils::StdDevF32(input4.data(), input4.size(), mean2))); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 218 | } |
| 219 | |
| 220 | TEST_CASE("Test FFT32") |
| 221 | { |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 222 | constexpr size_t nElem = 512; |
| 223 | |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 224 | /*Test Constants: */ |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 225 | std::vector<float> input_zeros(nElem, 0); |
| 226 | std::vector<float> input_ones(nElem, 1); |
| 227 | |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 228 | /* Random numbers generated using numpy rand with range [0:1] */ |
| 229 | std::vector<float> input_random{ |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 230 | 0.42333686, 0.6547418, 0.8933691, 0.91466254, 0.5992143, 0.99474055, 0.97750413, |
| 231 | 0.97160685, 0.72718734, 0.8699537, 0.643911, 0.09764466, 0.0050113136, 0.46823388, |
| 232 | 0.13709934, 0.44892532, 0.59728205, 0.04055081, 0.5579888, 0.18445836, 0.66469765, |
| 233 | 0.82715863, 0.91934484, 0.7844356, 0.23489648, 0.021708783, 0.67819905, 0.75761676, |
| 234 | 0.48374954, 0.14006922, 0.87082034, 0.7694296, 0.80479276, 0.8241704, 0.95917296, |
| 235 | 0.5758142, 0.16839339, 0.34290153, 0.5846108, 0.6878044, 0.1067114, 0.5198196, |
| 236 | 0.4356897, 0.68049103, 0.12480807, 0.3538696, 0.06067087, 0.056964435, 0.5382167, |
| 237 | 0.07761527, 0.6673144, 0.9045368, 0.11050189, 0.03530183, 0.07864744, 0.98752064, |
| 238 | 0.014321936, 0.101833574, 0.43293256, 0.87102246, 0.52411795, 0.90232223, 0.49560344, |
| 239 | 0.6803092, 0.2908511, 0.14653015, 0.99105513, 0.7057098, 0.09623502, 0.039713606, |
| 240 | 0.88669086, 0.56018597, 0.90632766, 0.99241334, 0.18748309, 0.38991618, 0.6359827, |
| 241 | 0.05665585, 0.732304, 0.2703365, 0.19014524, 0.5017947, 0.78862536, 0.81253093, |
| 242 | 0.35050204, 0.2832596, 0.65221876, 0.59856164, 0.42758793, 0.78865635, 0.30943435, |
| 243 | 0.93780816, 0.62568265, 0.35397422, 0.84209913, 0.48590583, 0.34837773, 0.5811646, |
| 244 | 0.42924216, 0.26692122, 0.030709852, 0.84459823, 0.09085059, 0.29297647, 0.48539516, |
| 245 | 0.33488297, 0.7877257, 0.8728821, 0.28454545, 0.7109578, 0.86097074, 0.8536262, |
| 246 | 0.4978063, 0.5760398, 0.77506036, 0.7716988, 0.27041402, 0.52340513, 0.2055419, |
| 247 | 0.8728235, 0.13492358, 0.79122984, 0.52998376, 0.33897072, 0.6426309, 0.8766521, |
| 248 | 0.89287037, 0.74047667, 0.42341164, 0.67437655, 0.4682156, 0.67123246, 0.54287183, |
| 249 | 0.3580476, 0.94756556, 0.2699457, 0.6131569, 0.75043845, 0.8115012, 0.49610943, |
| 250 | 0.7108478, 0.90941435, 0.02233071, 0.37346774, 0.33732748, 0.46691266, 0.35784695, |
| 251 | 0.39391598, 0.8556212, 0.884142, 0.11730601, 0.550112, 0.31513855, 0.69654715, |
| 252 | 0.58585805, 0.4493127, 0.78515726, 0.8176612, 0.9846698, 0.32842383, 0.41843212, |
| 253 | 0.48470423, 0.6757128, 0.95876855, 0.5989163, 0.13587572, 0.72886884, 0.88291156, |
| 254 | 0.34402263, 0.66211045, 0.86188424, 0.21498202, 0.26397392, 0.67372984, 0.91386956, |
| 255 | 0.7339788, 0.91308993, 0.1953016, 0.1539217, 0.214701, 0.58234113, 0.8019992, |
| 256 | 0.63969976, 0.041050985, 0.7293308, 0.26341477, 0.54768014, 0.97596467, 0.12385198, |
| 257 | 0.44149798, 0.5519762, 0.1697347, 0.577215, 0.8213594, 0.47874716, 0.64515114, |
| 258 | 0.61467725, 0.18463866, 0.23890929, 0.51052976, 0.16807361, 0.53142565, 0.2414274, |
| 259 | 0.41690814, 0.98815554, 0.6245643, 0.9477003, 0.24780034, 0.82469565, 0.8614785, |
| 260 | 0.9565832, 0.062440686, 0.9710724, 0.039196696, 0.11030199, 0.35234734, 0.02065066, |
| 261 | 0.12832293, 0.7328055, 0.48924434, 0.17247158, 0.5769348, 0.44146806, 0.53575355, |
| 262 | 0.17258933, 0.6980237, 0.86494404, 0.50573164, 0.5033998, 0.71199447, 0.41353586, |
| 263 | 0.26767612, 0.3789118, 0.046621118, 0.58491063, 0.22861995, 0.03134273, 0.53280216, |
| 264 | 0.23382367, 0.07748905, 0.96875405, 0.6613716, 0.64087844, 0.8377165, 0.051519375, |
| 265 | 0.68997836, 0.3776376, 0.43362603, 0.5358754, 0.51419014, 0.12823892, 0.26574057, |
| 266 | 0.508808, 0.15734084, 0.78327274, 0.5045347, 0.5445746, 0.89297736, 0.8531272, |
| 267 | 0.91270804, 0.87429863, 0.3965137, 0.13544834, 0.74269205, 0.80592203, 0.045050766, |
| 268 | 0.13362087, 0.17090783, 0.02873757, 0.99339336, 0.6394376, 0.48203012, 0.70598215, |
| 269 | 0.37082237, 0.39792424, 0.89938444, 0.312602, 0.48755112, 0.18220617, 0.17303479, |
| 270 | 0.31954846, 0.78080165, 0.1755106, 0.68262285, 0.84665287, 0.8520143, 0.8459509, |
| 271 | 0.39417005, 0.30087698, 0.81362164, 0.61927587, 0.32739028, 0.9023775, 0.27578092, |
| 272 | 0.6830477, 0.15842387, 0.8473049, 0.43057114, 0.2019703, 0.20560141, 0.6237757, |
| 273 | 0.60283095, 0.27645138, 0.26605442, 0.27985683, 0.41353813, 0.85139906, 0.71711886, |
| 274 | 0.5444832, 0.73613757, 0.7397004, 0.7406752, 0.41016674, 0.31896713, 0.4541723, |
| 275 | 0.2795807, 0.47941738, 0.00504193, 0.89091027, 0.8097144, 0.63033766, 0.37252298, |
| 276 | 0.9132861, 0.5102532, 0.04104481, 0.30368647, 0.21573475, 0.99520445, 0.5047808, |
| 277 | 0.6868845, 0.99881023, 0.30377692, 0.2554386, 0.47201005, 0.11120686, 0.10077732, |
| 278 | 0.1853349, 0.49159425, 0.3938629, 0.8989509, 0.9887155, 0.698771, 0.695701, |
| 279 | 0.78368753, 0.52537227, 0.19451462, 0.3659248, 0.1968508, 0.7751828, 0.33103722, |
| 280 | 0.40406147, 0.37832898, 0.68663514, 0.32225925, 0.41771907, 0.034218453, 0.42808908, |
| 281 | 0.20685343, 0.1861495, 0.045986768, 0.8532299, 0.17200677, 0.44670314, 0.56831235, |
| 282 | 0.5388232, 0.5430553, 0.69175136, 0.6462231, 0.42827028, 0.10050113, 0.30627027, |
| 283 | 0.9967943, 0.6684778, 0.5928422, 0.63392985, 0.99123496, 0.79301435, 0.7936309, |
| 284 | 0.42839453, 0.39781123, 0.22329247, 0.0122212395, 0.2807108, 0.19812097, 0.5576105, |
| 285 | 0.115653396, 0.3732018, 0.7622857, 0.19847734, 0.5310287, 0.7298145, 0.5518292, |
| 286 | 0.9117333, 0.13215758, 0.33716795, 0.42372775, 0.6779287, 0.35799992, 0.097887225, |
| 287 | 0.20171605, 0.9948177, 0.1829232, 0.80349857, 0.9807098, 0.22959666, 0.67322475, |
| 288 | 0.63094735, 0.93454355, 0.15962408, 0.04335433, 0.47104993, 0.36784375, 0.45258796, |
| 289 | 0.93415564, 0.1655446, 0.7195017, 0.76236975, 0.3846913, 0.01330617, 0.84716374, |
| 290 | 0.1227003, 0.65102947, 0.6632434, 0.3728453, 0.4222391, 0.6942989, 0.16014872, |
| 291 | 0.10798196, 0.94033676, 0.026525471, 0.8379024, 0.5484514, 0.13500613, 0.22919805, |
| 292 | 0.7001831, 0.6573261, 0.38086265, 0.8725666, 0.35077834, 0.28415123, 0.42283052, |
| 293 | 0.668379, 0.9769895, 0.37621376, 0.646407, 0.11188069, 0.17129017, 0.7441628, |
| 294 | 0.25617477, 0.7751679, 0.8565412, 0.67631435, 0.45213568, 0.61896557, 0.3387995, |
| 295 | 0.51607716, 0.60779697, 0.16428445, 0.5080923, 0.13012086, 0.61184275, 0.7690249, |
| 296 | 0.9578811, 0.67365676, 0.16241212, 0.97157824, 0.5595742, 0.75936574, 0.6043881, |
| 297 | 0.2149638, 0.4925318, 0.58727825, 0.97953695, 0.01605968, 0.2819307, 0.6448378, |
| 298 | 0.4265335, 0.661541, 0.3976571, 0.40607136, 0.46425515, 0.2055872, 0.2716193, |
| 299 | 0.4132582, 0.8372537, 0.37787434, 0.082228854, 0.7985557, 0.9718134, 0.35222608, |
| 300 | 0.4853643, 0.2569464, 0.14783978, 0.4889042, 0.62900156, 0.19994198, 0.4618481, |
| 301 | 0.21673755, 0.51749533, 0.1260157, 0.83759904, 0.36438805, 0.6704668, 0.22010763, |
| 302 | 0.2359318, 0.53004104, 0.9723652, 0.91218954, 0.9153926, 0.48207277, 0.34850466, |
| 303 | 0.8939421}; |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 304 | std::vector<std::vector<float>> input_vectors{input_zeros, |
| 305 | input_ones, |
| 306 | input_random}; |
| 307 | |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 308 | std::vector<float> output_zeros(nElem, 0); |
| 309 | std::vector<float> output_ones(nElem, 0); |
| 310 | std::vector<float> output_random(nElem, 0); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 311 | std::vector<std::vector<float>> output_vectors{output_zeros, |
| 312 | output_ones, |
| 313 | output_random}; |
| 314 | |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 315 | std::vector<float> expected_result_zeros(nElem, 0); |
| 316 | std::vector<float> expected_result_ones(nElem, 0); |
| 317 | expected_result_ones[0] = static_cast<float>(nElem); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 318 | |
| 319 | /* Values are stored as [real0, realN/2, real1, im1, real2, im2, ...] */ |
| 320 | std::vector<float> expected_result_random{ |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 321 | 259.510161, 2.59796867, 2.55982143, -5.91349888, -1.80049237, 1.09902763, |
| 322 | 4.0094324, 7.76684892, 4.32617219, -4.33636417, 2.98128463, -0.83763449, |
| 323 | 2.92973078, -3.75655459, 2.27203161, -4.61106145, 5.55562176, -4.71880166, |
| 324 | 2.13693416, -2.20496619, 2.31174036, -3.52991041, -0.61687068, 2.43455407, |
| 325 | 4.76317833, 8.66518565, 1.72350562, -1.33641312, 5.82836675, 0.89396187, |
| 326 | 8.15031483, 4.34599034, 5.99780199, 2.94900065, -0.0234462045, -5.03789597, |
| 327 | 12.190702, -6.47012928, 1.24434715, 0.23621713, 0.920921279, -9.20510398, |
| 328 | -0.773267254, -3.72141078, 6.28883709, 4.00634065, 4.46491682, 0.74625307, |
| 329 | -0.587506158, -4.22833058, 3.15189786, -1.82518672, -6.9378226, 2.4170692, |
| 330 | 3.23045185, 3.33383799, 0.0510059531, -3.4233929, -2.91651323, -0.0258584, |
| 331 | 5.84499843, -9.51454903, -14.9214047, 5.52200123, 4.5217959, -7.08268703, |
| 332 | 0.51677542, -2.90878759, 5.04314682, -1.16928599, -10.7329243, -0.2719951, |
| 333 | -3.95269565, -2.32475678, -4.11031641, -2.20538835, -0.589005095, -5.65483456, |
| 334 | -10.8927018, 5.74801823, -5.72520347, -3.94970165, -0.518407515, 1.23622633, |
| 335 | 9.56297959, 0.24424306, -3.74306351, 9.63476301, -4.74493837, -0.35443496, |
| 336 | 6.54760504, 1.16188913, -13.341695, 7.19088609, -2.560458, -0.49557866, |
| 337 | 2.93460322, 6.91076746, -0.284779221, 6.59958391, 1.70963995, -7.67293252, |
| 338 | 4.1850079, -0.14627552, -1.24855113, -1.43322867, 0.360644904, -4.11521374, |
| 339 | -10.0628421, 2.87531563, -9.34809732, -0.58251846, 3.61799848, -5.10288284, |
| 340 | -5.96239076, 1.99792128, -0.0783229243, 1.81741166, 2.32709681, 0.68487206, |
| 341 | -3.08398468, 1.5177629, 1.41015388, 4.51146401, 1.90769911, 0.56093423, |
| 342 | 3.58389141, -0.14974575, -2.20163907, -1.62177814, -1.91904127, 1.94645907, |
| 343 | -0.13772293, -4.30291678, -2.61843435, 0.12691241, 3.28959117, 4.7309582, |
| 344 | -2.93995652, 0.39835926, 2.89711768, 1.42284586, -5.13129145, -7.26477374, |
| 345 | 3.74616158, -2.59659457, 3.8574875, -2.93737277, 3.17748694, -4.45041455, |
| 346 | -2.68466437, 1.37377726, 1.60008368, 1.63787578, -1.95661401, -6.34937202, |
| 347 | -2.62744282, -5.20892662, 0.890553959, -6.37113573, 2.35885332, -10.04547561, |
| 348 | 0.329866159, 1.89217741, 0.882516491, -3.53298728, -2.22525608, -5.64794388, |
| 349 | -5.19226843, 2.5971315, 4.49346648, -0.20428409, -6.14851885, -11.90893875, |
| 350 | 3.75899776, -1.86910056, 2.78518535, -2.6359501, -2.13423317, 4.86509946, |
| 351 | -2.37625499, 7.42404308, 6.71175474, 6.06191618, 2.59014379, 4.76329698, |
| 352 | 9.19140042, 9.69149015, 3.33307819, -4.03094924, 2.12988453, -0.15820258, |
| 353 | -10.3422801, -3.04462388, 3.59852152, -2.00887343, -3.69998656, 0.90050102, |
| 354 | 0.679959099, -1.88604949, 1.24235316, 0.41309537, 6.13876866, -7.1040085, |
| 355 | 6.17728674, 1.91667103, -1.32895472, -0.17674661, -6.94720428, 3.10502593, |
| 356 | -2.33990738, -1.27840434, 3.2144252, 2.14102714, 2.37498837, 3.8158066, |
| 357 | -2.24107675, -5.52527559, -2.9569793, -0.50367608, 3.01687661, 7.08195792, |
| 358 | 6.7860479, -3.94154162, 2.24402195, 4.60132638, 3.42211139, 4.17689039, |
| 359 | -1.17277194, 2.15404472, -2.3748193, 1.42611867, -0.463033506, 3.21563035, |
| 360 | 1.38662123, 3.98598717, -3.75283402, -2.47600433, -1.97290542, 2.83361487, |
| 361 | -0.845662834, 5.57411581, -0.972981483, -11.394208, 1.88220611, -0.80225125, |
| 362 | -0.434295854, -8.2954126, 3.81795409, -3.17146, -4.61994107, -1.59820505, |
| 363 | -5.98834455, -4.93129451, -0.513862996, -0.15649305, -5.59094391, 6.25244435, |
| 364 | -6.59974456, 13.17193115, 4.48609092, 1.64741879, 7.40985006, 0.44896188, |
| 365 | 3.81058449, -0.76425931, -5.47938416, 4.01447941, -3.21535548, -1.45542238, |
| 366 | 0.72274083, -0.23983128, -4.32373034, 0.1337671, -5.89365226, -3.18756318, |
| 367 | 7.90979161, 5.27570134, -3.43094553, -6.00826981, 1.17932561, -3.50027177, |
| 368 | 0.181306385, 1.1062498, 0.723650536, -1.55500613, -3.88047911, -2.43746762, |
| 369 | -6.81565579, 2.16343352, 2.46366137, 2.38704469, -2.55106395, 6.5091449, |
| 370 | -2.06510578, 11.11320924, 2.06649835, -1.05026064, 1.63564303, -0.04638729, |
| 371 | 1.45053876, 0.43730146, 1.25027939, 0.79932743, 2.81088838, 6.95136058, |
| 372 | -4.41417255, 2.89610628, 1.15426258, -2.60704937, -2.77744882, 4.12872365, |
| 373 | -2.98288336, 6.75607352, 2.36553382, -2.10540332, -7.30042988, -5.44897893, |
| 374 | 3.44048454, 4.29726231, 2.181995, -0.80126759, -4.04051175, 4.57584864, |
| 375 | 0.956312116, -4.45183318, 3.42348929, -9.84138181, 8.69604433, -6.6481311, |
| 376 | 0.468232735, 1.41031176, -1.240857, 10.61672181, 0.356591473, 10.51631421, |
| 377 | -0.99743547, 2.72157537, 8.63583929, -2.19404252, -1.53605811, -4.41068581, |
| 378 | 2.05371873, 1.25665769, 1.65289503, 4.52520582, -0.535062642, 0.82084677, |
| 379 | -11.0079476, -5.09361474, 9.63129107, 3.90056638, -4.19779738, 0.06565745, |
| 380 | 2.42526917, 2.5854233, -3.66709357, 3.80502971, -0.101489353, -6.85423228, |
| 381 | -13.9361494, -0.43904617, 6.01800968, -1.30751495, -4.75122234, 2.74740671, |
| 382 | 5.54971138, 9.43409003, -0.994733058, -3.0096825, -8.60263376, 0.36653762, |
| 383 | 3.53318614, -2.69194556, -8.9514574, -4.71570923, 5.15417709, 2.68645385, |
| 384 | -2.78042293, 8.21739385, 0.590225003, 2.13319153, 1.72158888, 0.18114627, |
| 385 | 3.92269446, 3.3525857, -3.40313825, 4.39280934, -1.70368966, 1.29121245, |
| 386 | -3.11326453, -1.85941318, 4.57078881, 0.72531039, 6.00445664, 4.9588524, |
| 387 | -3.32944491, -0.02080722, -7.42374632, -3.23290026, -0.81614579, 3.55935439, |
| 388 | -0.619206533, 2.42859073, 2.21486456, 3.76402487, 3.90930695, -3.61610186, |
| 389 | -0.812712547, -14.63377988, 1.14460823, -3.14089899, 3.18097435, 1.21957751, |
| 390 | 2.85181833, 0.89990235, -4.32147361, -5.54219361, -1.12253677, 2.96141081, |
| 391 | -4.4257707, 3.17282306, 4.9174671, 1.16977744, -4.55148089, -2.82520179, |
| 392 | -1.71684103, 1.91487668, 0.770726836, 0.78534837, -5.91048566, -4.8288477, |
| 393 | -1.35560162, -3.60938315, 1.15812301, 2.44299541, 1.3611519, -5.40950935, |
| 394 | 7.08292127, 0.27720591, -0.160210828, 2.75862348, -1.57403782, 9.97207524, |
| 395 | -2.08957576, 8.70299964, -5.33004663, 4.1547783, 3.51580675, -5.10788085, |
| 396 | 4.37938353, -3.73449894, 1.44673271, 0.51941469, 0.852232446, -1.1134965, |
| 397 | -1.43972745, -1.62952127, -2.50759973, 1.19012213, 0.572772282, -2.71833059, |
| 398 | -6.8471899, 4.2621535, 1.58954734, -0.53827818, -0.144624396, 7.63866979, |
| 399 | 0.410423977, -2.4785678, -5.02681867, -2.03469811, 0.959505727, 2.68589705, |
| 400 | 3.20889444, -10.76452533, -3.84771551, 2.49189796, -3.19895938, -3.49948794, |
| 401 | -2.6723897, 5.11386526, -3.85957031, -1.40741978, 0.176663166, -11.7111276, |
| 402 | 0.639997364, -1.30321198, 3.20767633, 1.65750671, -11.6187257, -4.36634782, |
| 403 | -3.18675281, -4.89279155, -4.08760307, 2.19269283, -1.5892487, 0.17948212, |
| 404 | 4.81376107, 2.01871001, -0.324211095, -0.2790092, 1.12603878, -3.61503491, |
| 405 | -2.86982317, -3.03634532, 8.0771391, 2.21302089, 2.91496011, -2.58564072, |
| 406 | 0.0, 0.0}; |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 407 | |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 408 | std::vector<std::vector<float>> expected_results_vectors{ |
| 409 | expected_result_zeros, expected_result_ones, expected_result_random}; |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 410 | arm::app::math::FftInstance fftInstance; |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 411 | |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 412 | /* Iterate over each of the input vectors, calculate FFT and compare with corresponding expected_results vectors */ |
Kshitij Sisodia | b59ba68 | 2021-11-23 17:19:52 +0000 | [diff] [blame] | 413 | for (size_t j = 0; j < input_vectors.size(); j++) { |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 414 | uint16_t fftLen = input_vectors[j].size(); |
| 415 | arm::app::math::MathUtils::FftInitF32(fftLen, fftInstance); |
| 416 | arm::app::math::MathUtils::FftF32(input_vectors[j], output_vectors[j], fftInstance); |
| 417 | |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 418 | const float tolerance = 10e-4; |
Kshitij Sisodia | b59ba68 | 2021-11-23 17:19:52 +0000 | [diff] [blame] | 419 | for (size_t i = 0; i < fftLen/2; i++) { |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 420 | CHECK(output_vectors[j][i] == Approx(expected_results_vectors[j][i]).margin(tolerance)); |
| 421 | } |
| 422 | } |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 423 | |
Kshitij Sisodia | 1e6c694 | 2023-05-05 10:56:37 +0100 | [diff] [blame^] | 424 | /* Test inverse FFTs using the forward FFT for complex numbers. |
| 425 | * IFFT(XVec) = (1/N)(Conj(FFT(Conj(Xvec)))) */ |
| 426 | for (size_t j = 0; j < input_vectors.size(); j++) { |
| 427 | const uint16_t fftLen = input_vectors[j].size(); |
| 428 | const size_t inputSz = fftLen * 2; |
| 429 | |
| 430 | /* This vector will populate the input for FFT for complex numbers. */ |
| 431 | std::vector<float> inputWithConjugates(inputSz); |
| 432 | |
| 433 | /* We expect the output of this test to return the original input. */ |
| 434 | std::vector<float> expectedOutputVector = input_vectors[j]; |
| 435 | |
| 436 | /* Placeholder for output vector. */ |
| 437 | std::vector<float> outputVector(inputWithConjugates.size()); |
| 438 | |
| 439 | /* Populate the 0 and N/2 elements (these will be real numbers |
| 440 | * only - no imaginary parts. */ |
| 441 | inputWithConjugates[0] = expected_results_vectors[j][0]; |
| 442 | inputWithConjugates[fftLen] = expected_results_vectors[j][1]; |
| 443 | |
| 444 | /* Populate the rest of the elements - conjugates of the original for the left mirror |
| 445 | * and the right side with what the left would have been, i.e., |
| 446 | * conjugate(conjugate(X_left)). */ |
| 447 | for (size_t i = 2; i < fftLen; i += 2) { |
| 448 | inputWithConjugates[i] = expected_results_vectors[j][i]; |
| 449 | inputWithConjugates[i + 1] = 0 - expected_results_vectors[j][i + 1]; |
| 450 | inputWithConjugates[fftLen + i] = expected_results_vectors[j][fftLen - i]; |
| 451 | inputWithConjugates[fftLen + i + 1] = expected_results_vectors[j][fftLen - i + 1]; |
| 452 | } |
| 453 | |
| 454 | arm::app::math::MathUtils::FftInitF32( |
| 455 | fftLen, fftInstance, arm::app::math::FftType::complex); |
| 456 | arm::app::math::MathUtils::FftF32(inputWithConjugates, outputVector, fftInstance); |
| 457 | |
| 458 | const float tolerance = 0.1; |
| 459 | for (size_t i = 0; i < expectedOutputVector.size(); i++) { |
| 460 | |
| 461 | /* The number returned here will be nElem times the output. */ |
| 462 | CHECK(outputVector[i * 2] / static_cast<float>(nElem) == |
| 463 | Approx(expectedOutputVector[i]).margin(tolerance)); |
| 464 | |
| 465 | /* The imaginary part here should be close to 0 as the original input |
| 466 | * we supplied was real. */ |
| 467 | CHECK(outputVector[i * 2 + 1] / nElem == Approx(0.f).margin(tolerance)); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 468 | } |
| 469 | } |
| 470 | } |
| 471 | |
| 472 | TEST_CASE("Test VecLogarithmF32") |
| 473 | { |
| 474 | /*Test Constants: */ |
| 475 | std::vector<float> input = |
| 476 | { 0.1e-10, 0.5, 1, M_PI, M_E }; |
| 477 | std::vector<float> expectedResult = |
| 478 | {-25.328436, -0.693147181, 0, 1.144729886, 1}; |
| 479 | std::vector<float> output(input.size()); |
| 480 | |
| 481 | arm::app::math::MathUtils::VecLogarithmF32(input,output); |
| 482 | |
| 483 | for (size_t i = 0; i < input.size(); i++) |
| 484 | CHECK (expectedResult[i] == Approx(output[i])); |
| 485 | } |
| 486 | |
| 487 | TEST_CASE("Test DotProductF32") |
| 488 | { |
| 489 | /*Test Constants: */ |
| 490 | std::vector<float> inputA |
| 491 | {1,1,1,0,0,0}; |
| 492 | std::vector<float> inputB |
| 493 | {0,0,0,1,1,1}; |
| 494 | uint32_t len = inputA.size(); |
| 495 | |
| 496 | float expectedResult = 0; |
| 497 | float dot_prod = arm::app::math::MathUtils::DotProductF32(inputA.data(), inputB.data(), len); |
| 498 | CHECK(dot_prod == expectedResult); |
| 499 | } |
| 500 | |
| 501 | TEST_CASE("Test ComplexMagnitudeSquaredF32") |
| 502 | { |
| 503 | /*Test Constants: */ |
| 504 | std::vector<float> input |
| 505 | {0.0, 0.0, 0.5, 0.5,1,1}; |
Kshitij Sisodia | b59ba68 | 2021-11-23 17:19:52 +0000 | [diff] [blame] | 506 | size_t inputLen = input.size(); |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 507 | |
| 508 | std::vector<float> expectedResult |
| 509 | {0.0, 0.5, 2,}; |
Kshitij Sisodia | b59ba68 | 2021-11-23 17:19:52 +0000 | [diff] [blame] | 510 | size_t outputLen = inputLen/2; |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 511 | std::vector<float>output(outputLen); |
| 512 | |
| 513 | /* Pass pointers to input/output vectors as this function over-writes the first half |
| 514 | * of the input vector with output results */ |
| 515 | arm::app::math::MathUtils::ComplexMagnitudeSquaredF32(input.data(), inputLen, output.data(), outputLen); |
| 516 | |
Kshitij Sisodia | b178b28 | 2022-01-04 13:37:53 +0000 | [diff] [blame] | 517 | for (size_t i = 0; i < outputLen; i++) { |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 518 | CHECK (expectedResult[i] == Approx(output[i])); |
Kshitij Sisodia | b178b28 | 2022-01-04 13:37:53 +0000 | [diff] [blame] | 519 | } |
| 520 | } |
| 521 | |
| 522 | /** |
| 523 | * @brief Simple function to test the Softmax function |
| 524 | * |
| 525 | * @param input Input vector |
| 526 | * @param goldenOutput Expected output vector |
| 527 | */ |
| 528 | static void TestSoftmaxF32( |
| 529 | const std::vector<float>& input, |
| 530 | const std::vector<float>& goldenOutput) |
| 531 | { |
| 532 | std::vector<float> output = input; /* Function modifies the vector in-place */ |
| 533 | arm::app::math::MathUtils::SoftmaxF32(output); |
| 534 | |
| 535 | for (size_t i = 0; i < goldenOutput.size(); ++i) { |
| 536 | CHECK(goldenOutput[i] == Approx(output[i])); |
| 537 | } |
| 538 | |
| 539 | REQUIRE(output.size() == goldenOutput.size()); |
| 540 | } |
| 541 | |
| 542 | TEST_CASE("Test SoftmaxF32") |
| 543 | { |
| 544 | SECTION("Simple series") { |
| 545 | const std::vector<float> input { |
| 546 | 0.0, 1.0, 2.0, 3.0, 4.0, |
| 547 | 5.0, 6.0, 7.0, 8.0, 9.0 |
| 548 | }; |
| 549 | |
| 550 | const std::vector<float> expectedOutput { |
| 551 | 7.80134161e-05, 2.12062451e-04, |
| 552 | 5.76445508e-04, 1.56694135e-03, |
| 553 | 4.25938820e-03, 1.15782175e-02, |
| 554 | 3.14728583e-02, 8.55520989e-02, |
| 555 | 2.32554716e-01, 6.32149258e-01 |
| 556 | }; |
| 557 | |
| 558 | TestSoftmaxF32(input, expectedOutput); |
| 559 | } |
| 560 | |
| 561 | SECTION("Random series") { |
| 562 | const std::vector<float> input { |
| 563 | 0.8810943246170809, 0.5877587675947015, |
| 564 | 0.6841546454788743, 0.4155920960071594, |
| 565 | 0.9799415323651671, 0.5066432973545711, |
| 566 | 0.3846024252355448, 0.4568689569632123, |
| 567 | 0.3284413744557605, 0.49152323726213554 |
| 568 | }; |
| 569 | |
| 570 | const std::vector<float> expectedOutput { |
| 571 | 0.13329595, 0.09940837, |
| 572 | 0.10946799, 0.08368583, |
| 573 | 0.14714509, 0.09166319, |
| 574 | 0.08113220, 0.08721240, |
| 575 | 0.07670132, 0.09028766 |
| 576 | }; |
| 577 | |
| 578 | TestSoftmaxF32(input, expectedOutput); |
| 579 | } |
| 580 | |
| 581 | SECTION("Series with large STD") { |
| 582 | const std::vector<float> input { |
| 583 | 0.001, 1000.000 |
| 584 | }; |
| 585 | |
| 586 | const std::vector<float> expectedOutput { |
| 587 | 0.000, 1.000 |
| 588 | }; |
| 589 | |
| 590 | TestSoftmaxF32(input, expectedOutput); |
| 591 | } |
liabar01 | dee53bc | 2021-10-29 15:59:04 +0100 | [diff] [blame] | 592 | } |