George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #include <catch.hpp> |
| 7 | #include <map> |
| 8 | #include <cinttypes> |
| 9 | #include "KeywordSpottingPipeline.hpp" |
| 10 | #include "DsCNNPreprocessor.hpp" |
| 11 | |
| 12 | static std::string GetResourceFilePath(const std::string& filename) |
| 13 | { |
| 14 | std::string testResources = TEST_RESOURCE_DIR; |
| 15 | if (testResources.empty()) |
| 16 | { |
| 17 | throw std::invalid_argument("Invalid test resources directory provided"); |
| 18 | } |
| 19 | else |
| 20 | { |
| 21 | if(testResources.back() != '/') |
| 22 | { |
| 23 | return testResources + "/" + filename; |
| 24 | } |
| 25 | else |
| 26 | { |
| 27 | return testResources + filename; |
| 28 | } |
| 29 | } |
| 30 | } |
| 31 | |
| 32 | TEST_CASE("Test Keyword spotting pipeline") |
| 33 | { |
| 34 | const int8_t ifm0_kws [] = |
| 35 | { |
| 36 | -0x1b, 0x4f, 0x7a, -0x55, 0x6, -0x11, 0x6e, -0x6, 0x67, -0x7e, -0xd, 0x6, 0x49, 0x79, -0x1e, 0xe, |
| 37 | 0x1d, 0x6e, 0x6f, 0x6f, -0x2e, -0x4b, 0x2, -0x3e, 0x40, -0x4b, -0x7, 0x31, -0x38, -0x64, -0x28, |
| 38 | 0xc, -0x1d, 0xf, 0x1c, 0x5a, -0x4b, 0x56, 0x7e, 0x9, -0x29, 0x13, -0x65, -0xa, 0x34, -0x59, 0x41, |
| 39 | -0x6f, 0x75, 0x67, -0x5f, 0x17, 0x4a, -0x76, -0x7a, 0x49, -0x19, -0x41, 0x78, 0x40, 0x44, 0xe, |
| 40 | -0x51, -0x5c, 0x3d, 0x24, 0x76, -0x66, -0x11, 0x5e, 0x7b, -0x4, 0x7a, 0x9, 0x13, 0x8, -0x21, -0x11, |
| 41 | 0x13, 0x7a, 0x25, 0x6, -0x68, 0x6a, -0x30, -0x16, -0x43, -0x27, 0x4c, 0x6b, -0x14, -0x12, -0x5f, |
| 42 | 0x49, -0x2a, 0x44, 0x57, -0x78, -0x72, 0x62, -0x8, -0x38, -0x73, -0x2, -0x80, 0x79, -0x3f, 0x57, |
| 43 | 0x9, -0x7e, -0x34, -0x59, 0x19, -0x66, 0x58, -0x3b, -0x69, -0x1a, 0x13, -0x2f, -0x2f, 0x13, 0x35, |
| 44 | -0x30, 0x1e, 0x3b, -0x71, 0x67, 0x7d, -0x5d, 0x1a, 0x69, -0x53, -0x38, -0xf, 0x76, 0x2, 0x7e, 0x45, |
| 45 | -0xa, 0x59, -0x6b, -0x28, -0x5d, -0x63, -0x7d, -0x3, 0x48, 0x74, -0x75, -0x7a, 0x1f, -0x53, 0x5b, |
| 46 | 0x4d, -0x18, -0x4a, 0x39, -0x52, 0x5a, -0x6b, -0x41, -0x3e, -0x61, -0x80, -0x52, 0x67, 0x71, -0x47, |
| 47 | 0x79, -0x41, 0x3a, -0x8, -0x1f, 0x4d, -0x7, 0x5b, 0x6b, -0x1b, -0x8, -0x20, -0x21, 0x7c, -0x74, |
| 48 | 0x25, -0x68, -0xe, -0x7e, -0x45, -0x28, 0x45, -0x1a, -0x39, 0x78, 0x11, 0x48, -0x6b, -0x7b, -0x43, |
| 49 | -0x21, 0x38, 0x46, 0x7c, -0x5d, 0x59, 0x53, -0x3f, -0x15, 0x59, -0x17, 0x75, 0x2f, 0x7c, 0x68, 0x6a, |
| 50 | 0x0, -0x10, 0x5b, 0x61, 0x36, -0x41, 0x33, 0x23, -0x80, -0x1d, -0xb, -0x56, 0x2d, 0x68, -0x68, |
| 51 | 0x2f, 0x48, -0x5d, -0x44, 0x64, -0x27, 0x68, -0x13, 0x39, -0x3f, 0x18, 0x31, 0x15, -0x78, -0x2, |
| 52 | 0x72, 0x60, 0x59, -0x30, -0x22, 0x73, 0x61, 0x76, -0x4, -0x62, -0x64, -0x80, -0x32, -0x16, 0x51, |
| 53 | -0x2, -0x70, 0x71, 0x3f, -0x5f, -0x35, -0x3c, 0x79, 0x48, 0x61, 0x5b, -0x20, -0x1e, -0x68, -0x1c, |
| 54 | 0x6c, 0x3a, 0x28, -0x36, -0x3e, 0x5f, -0x75, -0x73, 0x1e, 0x75, -0x66, -0x22, 0x20, -0x64, 0x67, |
| 55 | 0x36, 0x14, 0x37, -0xa, -0xe, 0x8, -0x37, -0x43, 0x21, -0x8, 0x54, 0x1, 0x34, -0x2c, -0x73, -0x11, |
| 56 | -0x48, -0x1c, -0x40, 0x14, 0x4e, -0x53, 0x25, 0x5e, 0x14, 0x4f, 0x7c, 0x6d, -0x61, -0x38, 0x35, |
| 57 | -0x5a, -0x44, 0x12, 0x52, -0x60, 0x22, -0x1c, -0x8, -0x4, -0x6b, -0x71, 0x43, 0xb, 0x7b, -0x7, |
| 58 | -0x3c, -0x3b, -0x40, -0xd, 0x44, 0x6, 0x30, 0x38, 0x57, 0x1f, -0x7, 0x2, 0x4f, 0x64, 0x7c, -0x3, |
| 59 | -0x13, -0x71, -0x45, -0x53, -0x52, 0x2b, -0x11, -0x1d, -0x2, -0x29, -0x37, 0x3d, 0x19, 0x76, 0x18, |
| 60 | 0x1d, 0x12, -0x29, -0x5e, -0x54, -0x48, 0x5d, -0x41, -0x3f, 0x7e, -0x2a, 0x41, 0x57, -0x65, -0x15, |
| 61 | 0x12, 0x1f, -0x57, 0x79, -0x64, 0x3a, -0x2f, 0x7f, -0x6c, 0xa, 0x52, -0x1f, -0x41, 0x6e, -0x4b, |
| 62 | 0x3d, -0x1b, -0x42, 0x22, -0x3c, -0x35, -0xf, 0xc, 0x32, -0x15, -0x68, -0x21, 0x0, -0x16, 0x14, |
| 63 | -0x10, -0x5b, 0x2f, 0x21, 0x41, -0x8, -0x12, -0xa, 0x10, 0xf, 0x7e, -0x76, -0x1d, 0x2b, -0x49, |
| 64 | 0x42, -0x25, -0x78, -0x69, -0x2c, 0x3f, 0xc, 0x52, 0x6d, 0x2e, -0x13, 0x76, 0x37, -0x36, -0x51, |
| 65 | -0x5, -0x63, -0x4f, 0x1c, 0x6b, -0x4b, 0x71, -0x12, 0x72, -0x3f,-0x4a, 0xf, 0x3a, -0xd, 0x38, 0x3b, |
| 66 | -0x5d, 0x75, -0x43, -0x10, -0xa, -0x7a, 0x1a, -0x44, 0x1c, 0x6a, 0x43, -0x1b, -0x35, 0x7d, -0x2c, |
| 67 | -0x10, 0x5b, -0x42, -0x4f, 0x69, 0x1f, 0x1b, -0x64, -0x21, 0x19, -0x5d, 0x2e, -0x2a, -0x65, -0x13, |
| 68 | -0x70, -0x6e |
| 69 | }; |
| 70 | |
| 71 | const int8_t ofm0_kws [] = |
| 72 | { |
| 73 | -0x80, 0x7f, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80 |
| 74 | }; |
| 75 | |
| 76 | // First 640 samples from yes.wav. |
| 77 | std::vector<int16_t> testWav = std::vector<int16_t> |
| 78 | { |
| 79 | 139, 143, 164, 163, 157, 156, 151, 148, 172, 171, |
| 80 | 165, 169, 149, 142, 145, 147, 166, 146, 112, 132, |
| 81 | 132, 136, 165, 176, 176, 152, 138, 158, 179, 185, |
| 82 | 183, 148, 121, 130, 167, 204, 163, 132, 165, 184, |
| 83 | 193, 205, 210, 204, 195, 178, 168, 197, 207, 201, |
| 84 | 197, 177, 185, 196, 191, 198, 196, 183, 193, 181, |
| 85 | 157, 170, 167, 159, 164, 152, 146, 167, 180, 171, |
| 86 | 194, 232, 204, 173, 171, 172, 184, 169, 175, 199, |
| 87 | 200, 195, 185, 214, 214, 193, 196, 191, 204, 191, |
| 88 | 172, 187, 183, 192, 203, 172, 182, 228, 232, 205, |
| 89 | 177, 174, 191, 210, 210, 211, 197, 177, 198, 217, |
| 90 | 233, 236, 203, 191, 169, 145, 149, 161, 198, 206, |
| 91 | 176, 137, 142, 181, 200, 215, 201, 188, 166, 162, |
| 92 | 184, 155, 135, 132, 126, 142, 169, 184, 172, 156, |
| 93 | 132, 119, 150, 147, 154, 160, 125, 130, 137, 154, |
| 94 | 161, 168, 195, 182, 160, 134, 138, 146, 130, 120, |
| 95 | 101, 122, 137, 118, 117, 131, 145, 140, 146, 148, |
| 96 | 148, 168, 159, 134, 114, 114, 130, 147, 147, 134, |
| 97 | 125, 98, 107, 127, 99, 79, 84, 107, 117, 114, |
| 98 | 93, 92, 127, 112, 109, 110, 96, 118, 97, 87, |
| 99 | 110, 95, 128, 153, 147, 165, 146, 106, 101, 137, |
| 100 | 139, 96, 73, 90, 91, 51, 69, 102, 100, 103, |
| 101 | 96, 101, 123, 107, 82, 89, 118, 127, 99, 100, |
| 102 | 111, 97, 111, 123, 106, 121, 133, 103, 100, 88, |
| 103 | 85, 111, 114, 125, 102, 91, 97, 84, 139, 157, |
| 104 | 109, 66, 72, 129, 111, 90, 127, 126, 101, 109, |
| 105 | 142, 138, 129, 159, 140, 80, 74, 78, 76, 98, |
| 106 | 68, 42, 106, 143, 112, 102, 115, 114, 82, 75, |
| 107 | 92, 80, 110, 114, 66, 86, 119, 101, 101, 103, |
| 108 | 118, 145, 85, 40, 62, 88, 95, 87, 73, 64, |
| 109 | 86, 71, 71, 105, 80, 73, 96, 92, 85, 90, |
| 110 | 81, 86, 105, 100, 89, 78, 102, 114, 95, 98, |
| 111 | 69, 70, 108, 112, 111, 90, 104, 137, 143, 160, |
| 112 | 145, 121, 98, 86, 91, 87, 115, 123, 109, 99, |
| 113 | 85, 120, 131, 116, 125, 144, 153, 111, 98, 110, |
| 114 | 93, 89, 101, 137, 155, 142, 108, 94, 136, 145, |
| 115 | 129, 129, 122, 109, 90, 76, 81, 110, 119, 96, |
| 116 | 95, 102, 105, 111, 90, 89, 111, 115, 86, 51, |
| 117 | 107, 140, 105, 105, 110, 142, 125, 76, 75, 69, |
| 118 | 65, 52, 61, 69, 55, 42, 47, 58, 37, 35, |
| 119 | 24, 20, 44, 22, 16, 26, 6, 3, 4, 23, |
| 120 | 60, 51, 30, 12, 24, 31, -9, -16, -13, 13, |
| 121 | 19, 9, 37, 55, 70, 36, 23, 57, 45, 33, |
| 122 | 50, 59, 18, 11, 62, 74, 52, 8, -3, 26, |
| 123 | 51, 48, -5, -9, 12, -7, -12, -5, 28, 41, |
| 124 | -2, -30, -13, 31, 33, -12, -22, -8, -15, -17, |
| 125 | 2, -6, -25, -27, -24, -8, 4, -9, -52, -47, |
| 126 | -9, -32, -45, -5, 41, 15, -32, -14, 2, -1, |
| 127 | -10, -30, -32, -25, -21, -17, -14, 8, -4, -13, |
| 128 | 34, 18, -36, -38, -18, -19, -28, -17, -14, -16, |
| 129 | -2, -20, -27, 12, 11, -17, -33, -12, -22, -64, |
| 130 | -42, -26, -23, -22, -37, -51, -53, -30, -18, -48, |
| 131 | -69, -38, -54, -96, -72, -49, -50, -57, -41, -22, |
| 132 | -43, -64, -54, -23, -49, -69, -41, -44, -42, -49, |
| 133 | -40, -26, -54, -50, -38, -49, -70, -94, -89, -69, |
| 134 | -56, -65, -71, -47, -39, -49, -79, -91, -56, -46, |
| 135 | -62, -86, -64, -32, -47, -50, -71, -77, -65, -68, |
| 136 | -52, -51, -61, -67, -61, -81, -93, -52, -59, -62, |
| 137 | -51, -75, -76, -50, -32, -54, -68, -70, -43, 1, |
| 138 | -42, -92, -80, -41, -38, -79, -69, -49, -82, -122, |
| 139 | -93, -21, -24, -61, -70, -73, -62, -74, -69, -43, |
| 140 | -25, -15, -43, -23, -26, -69, -44, -12, 1, -51, |
| 141 | -78, -13, 3, -53, -105, -72, -24, -62, -66, -31, |
| 142 | -40, -65, -86, -64, -44, -55, -63, -61, -37, -41, |
| 143 | }; |
| 144 | |
| 145 | // Golden audio ops mfcc output for the above wav. |
| 146 | const std::vector<float> testWavMfcc |
| 147 | { |
| 148 | -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072, |
| 149 | }; |
| 150 | |
| 151 | std::vector<float> testWavFloat(640); |
| 152 | constexpr float normaliser = 1.0/(1u<<15u); |
| 153 | std::transform(testWav.begin(), testWav.end(), testWavFloat.begin(), |
| 154 | std::bind1st(std::multiplies<float>(), normaliser)); |
| 155 | |
| 156 | const float DsCNNInputQuantizationScale = 1.107164; |
| 157 | const int DsCNNInputQuantizationOffset = 95; |
| 158 | |
| 159 | std::map<int,std::string> labels = |
| 160 | { |
| 161 | {0,"silence"}, |
| 162 | {1, "unknown"}, |
| 163 | { 2, "yes"}, |
| 164 | { 3,"no"}, |
| 165 | { 4, "up"}, |
| 166 | { 5, "down"}, |
| 167 | { 6, "left"}, |
| 168 | { 7, "right"}, |
| 169 | { 8, "on"}, |
| 170 | { 9, "off"}, |
| 171 | { 10, "stop"}, |
| 172 | {11, "go"} |
| 173 | }; |
| 174 | common::PipelineOptions options; |
| 175 | options.m_ModelFilePath = GetResourceFilePath("ds_cnn_clustered_int8.tflite"); |
| 176 | options.m_ModelName = "DS_CNN_CLUSTERED_INT8"; |
| 177 | options.m_backends = {"CpuAcc", "CpuRef"}; |
| 178 | kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(options); |
| 179 | |
| 180 | CHECK(kwsPipeline->getInputSamplesSize() == 16000); |
| 181 | std::vector<int8_t> expectedWavMfcc; |
| 182 | for(auto& i : testWavMfcc) |
| 183 | { |
| 184 | expectedWavMfcc.push_back( |
| 185 | (i + DsCNNInputQuantizationScale * DsCNNInputQuantizationOffset) / DsCNNInputQuantizationScale); |
| 186 | } |
| 187 | |
| 188 | SECTION("Pre-processing") |
| 189 | { |
| 190 | testWavFloat.resize(16000); |
| 191 | expectedWavMfcc.resize(49 * 10); |
| 192 | std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(testWavFloat); |
| 193 | CHECK(preprocessedData.size() == expectedWavMfcc.size()); |
| 194 | for(int i = 0; i < 10; ++i) |
| 195 | { |
| 196 | CHECK(expectedWavMfcc[i] == Approx(preprocessedData[i]).margin(1)); |
| 197 | } |
| 198 | } |
| 199 | |
| 200 | SECTION("Execute inference") |
| 201 | { |
| 202 | common::InferenceResults<int8_t> result; |
| 203 | std::vector<int8_t> IFM(std::begin(ifm0_kws), std::end(ifm0_kws)); |
| 204 | kwsPipeline->Inference(IFM, result); |
| 205 | std::vector<int8_t> OFM(std::begin(ofm0_kws), std::end(ofm0_kws)); |
| 206 | |
| 207 | CHECK(1 == result.size()); |
| 208 | CHECK(OFM.size() == result[0].size()); |
| 209 | |
| 210 | int count = 0; |
| 211 | for (auto& i : result) |
| 212 | { |
| 213 | for (signed char& j : i) |
| 214 | { |
| 215 | CHECK(j == OFM[count++]); |
| 216 | |
| 217 | } |
| 218 | } |
| 219 | } |
| 220 | |
| 221 | SECTION("Convert inference result to keyword") |
| 222 | { |
| 223 | std::vector< std::vector< int8_t >> modelOutput = {{1, 4, 2, 3, 1, 1, 3, 1, 43, 1, 6, 1}}; |
| 224 | kwsPipeline->PostProcessing(modelOutput, labels, |
| 225 | [](int index, std::string& label, float prob) -> void { |
| 226 | CHECK(index == 8); |
| 227 | CHECK(label == "on"); |
| 228 | }); |
| 229 | } |
| 230 | } |