blob: 9fb87fd3c97b90b70ba27b4f9463526536eca445 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
#include <algorithm>
#include <cinttypes>
#include <cmath>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include <catch.hpp>

#include "KeywordSpottingPipeline.hpp"
#include "DsCNNPreprocessor.hpp"
11
12static std::string GetResourceFilePath(const std::string& filename)
13{
14 std::string testResources = TEST_RESOURCE_DIR;
15 if (testResources.empty())
16 {
17 throw std::invalid_argument("Invalid test resources directory provided");
18 }
19 else
20 {
21 if(testResources.back() != '/')
22 {
23 return testResources + "/" + filename;
24 }
25 else
26 {
27 return testResources + filename;
28 }
29 }
30}
31
32TEST_CASE("Test Keyword spotting pipeline")
33{
34 const int8_t ifm0_kws [] =
35 {
36 -0x1b, 0x4f, 0x7a, -0x55, 0x6, -0x11, 0x6e, -0x6, 0x67, -0x7e, -0xd, 0x6, 0x49, 0x79, -0x1e, 0xe,
37 0x1d, 0x6e, 0x6f, 0x6f, -0x2e, -0x4b, 0x2, -0x3e, 0x40, -0x4b, -0x7, 0x31, -0x38, -0x64, -0x28,
38 0xc, -0x1d, 0xf, 0x1c, 0x5a, -0x4b, 0x56, 0x7e, 0x9, -0x29, 0x13, -0x65, -0xa, 0x34, -0x59, 0x41,
39 -0x6f, 0x75, 0x67, -0x5f, 0x17, 0x4a, -0x76, -0x7a, 0x49, -0x19, -0x41, 0x78, 0x40, 0x44, 0xe,
40 -0x51, -0x5c, 0x3d, 0x24, 0x76, -0x66, -0x11, 0x5e, 0x7b, -0x4, 0x7a, 0x9, 0x13, 0x8, -0x21, -0x11,
41 0x13, 0x7a, 0x25, 0x6, -0x68, 0x6a, -0x30, -0x16, -0x43, -0x27, 0x4c, 0x6b, -0x14, -0x12, -0x5f,
42 0x49, -0x2a, 0x44, 0x57, -0x78, -0x72, 0x62, -0x8, -0x38, -0x73, -0x2, -0x80, 0x79, -0x3f, 0x57,
43 0x9, -0x7e, -0x34, -0x59, 0x19, -0x66, 0x58, -0x3b, -0x69, -0x1a, 0x13, -0x2f, -0x2f, 0x13, 0x35,
44 -0x30, 0x1e, 0x3b, -0x71, 0x67, 0x7d, -0x5d, 0x1a, 0x69, -0x53, -0x38, -0xf, 0x76, 0x2, 0x7e, 0x45,
45 -0xa, 0x59, -0x6b, -0x28, -0x5d, -0x63, -0x7d, -0x3, 0x48, 0x74, -0x75, -0x7a, 0x1f, -0x53, 0x5b,
46 0x4d, -0x18, -0x4a, 0x39, -0x52, 0x5a, -0x6b, -0x41, -0x3e, -0x61, -0x80, -0x52, 0x67, 0x71, -0x47,
47 0x79, -0x41, 0x3a, -0x8, -0x1f, 0x4d, -0x7, 0x5b, 0x6b, -0x1b, -0x8, -0x20, -0x21, 0x7c, -0x74,
48 0x25, -0x68, -0xe, -0x7e, -0x45, -0x28, 0x45, -0x1a, -0x39, 0x78, 0x11, 0x48, -0x6b, -0x7b, -0x43,
49 -0x21, 0x38, 0x46, 0x7c, -0x5d, 0x59, 0x53, -0x3f, -0x15, 0x59, -0x17, 0x75, 0x2f, 0x7c, 0x68, 0x6a,
50 0x0, -0x10, 0x5b, 0x61, 0x36, -0x41, 0x33, 0x23, -0x80, -0x1d, -0xb, -0x56, 0x2d, 0x68, -0x68,
51 0x2f, 0x48, -0x5d, -0x44, 0x64, -0x27, 0x68, -0x13, 0x39, -0x3f, 0x18, 0x31, 0x15, -0x78, -0x2,
52 0x72, 0x60, 0x59, -0x30, -0x22, 0x73, 0x61, 0x76, -0x4, -0x62, -0x64, -0x80, -0x32, -0x16, 0x51,
53 -0x2, -0x70, 0x71, 0x3f, -0x5f, -0x35, -0x3c, 0x79, 0x48, 0x61, 0x5b, -0x20, -0x1e, -0x68, -0x1c,
54 0x6c, 0x3a, 0x28, -0x36, -0x3e, 0x5f, -0x75, -0x73, 0x1e, 0x75, -0x66, -0x22, 0x20, -0x64, 0x67,
55 0x36, 0x14, 0x37, -0xa, -0xe, 0x8, -0x37, -0x43, 0x21, -0x8, 0x54, 0x1, 0x34, -0x2c, -0x73, -0x11,
56 -0x48, -0x1c, -0x40, 0x14, 0x4e, -0x53, 0x25, 0x5e, 0x14, 0x4f, 0x7c, 0x6d, -0x61, -0x38, 0x35,
57 -0x5a, -0x44, 0x12, 0x52, -0x60, 0x22, -0x1c, -0x8, -0x4, -0x6b, -0x71, 0x43, 0xb, 0x7b, -0x7,
58 -0x3c, -0x3b, -0x40, -0xd, 0x44, 0x6, 0x30, 0x38, 0x57, 0x1f, -0x7, 0x2, 0x4f, 0x64, 0x7c, -0x3,
59 -0x13, -0x71, -0x45, -0x53, -0x52, 0x2b, -0x11, -0x1d, -0x2, -0x29, -0x37, 0x3d, 0x19, 0x76, 0x18,
60 0x1d, 0x12, -0x29, -0x5e, -0x54, -0x48, 0x5d, -0x41, -0x3f, 0x7e, -0x2a, 0x41, 0x57, -0x65, -0x15,
61 0x12, 0x1f, -0x57, 0x79, -0x64, 0x3a, -0x2f, 0x7f, -0x6c, 0xa, 0x52, -0x1f, -0x41, 0x6e, -0x4b,
62 0x3d, -0x1b, -0x42, 0x22, -0x3c, -0x35, -0xf, 0xc, 0x32, -0x15, -0x68, -0x21, 0x0, -0x16, 0x14,
63 -0x10, -0x5b, 0x2f, 0x21, 0x41, -0x8, -0x12, -0xa, 0x10, 0xf, 0x7e, -0x76, -0x1d, 0x2b, -0x49,
64 0x42, -0x25, -0x78, -0x69, -0x2c, 0x3f, 0xc, 0x52, 0x6d, 0x2e, -0x13, 0x76, 0x37, -0x36, -0x51,
65 -0x5, -0x63, -0x4f, 0x1c, 0x6b, -0x4b, 0x71, -0x12, 0x72, -0x3f,-0x4a, 0xf, 0x3a, -0xd, 0x38, 0x3b,
66 -0x5d, 0x75, -0x43, -0x10, -0xa, -0x7a, 0x1a, -0x44, 0x1c, 0x6a, 0x43, -0x1b, -0x35, 0x7d, -0x2c,
67 -0x10, 0x5b, -0x42, -0x4f, 0x69, 0x1f, 0x1b, -0x64, -0x21, 0x19, -0x5d, 0x2e, -0x2a, -0x65, -0x13,
68 -0x70, -0x6e
69 };
70
71 const int8_t ofm0_kws [] =
72 {
73 -0x80, 0x7f, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80, -0x80
74 };
75
76 // First 640 samples from yes.wav.
77 std::vector<int16_t> testWav = std::vector<int16_t>
78 {
79 139, 143, 164, 163, 157, 156, 151, 148, 172, 171,
80 165, 169, 149, 142, 145, 147, 166, 146, 112, 132,
81 132, 136, 165, 176, 176, 152, 138, 158, 179, 185,
82 183, 148, 121, 130, 167, 204, 163, 132, 165, 184,
83 193, 205, 210, 204, 195, 178, 168, 197, 207, 201,
84 197, 177, 185, 196, 191, 198, 196, 183, 193, 181,
85 157, 170, 167, 159, 164, 152, 146, 167, 180, 171,
86 194, 232, 204, 173, 171, 172, 184, 169, 175, 199,
87 200, 195, 185, 214, 214, 193, 196, 191, 204, 191,
88 172, 187, 183, 192, 203, 172, 182, 228, 232, 205,
89 177, 174, 191, 210, 210, 211, 197, 177, 198, 217,
90 233, 236, 203, 191, 169, 145, 149, 161, 198, 206,
91 176, 137, 142, 181, 200, 215, 201, 188, 166, 162,
92 184, 155, 135, 132, 126, 142, 169, 184, 172, 156,
93 132, 119, 150, 147, 154, 160, 125, 130, 137, 154,
94 161, 168, 195, 182, 160, 134, 138, 146, 130, 120,
95 101, 122, 137, 118, 117, 131, 145, 140, 146, 148,
96 148, 168, 159, 134, 114, 114, 130, 147, 147, 134,
97 125, 98, 107, 127, 99, 79, 84, 107, 117, 114,
98 93, 92, 127, 112, 109, 110, 96, 118, 97, 87,
99 110, 95, 128, 153, 147, 165, 146, 106, 101, 137,
100 139, 96, 73, 90, 91, 51, 69, 102, 100, 103,
101 96, 101, 123, 107, 82, 89, 118, 127, 99, 100,
102 111, 97, 111, 123, 106, 121, 133, 103, 100, 88,
103 85, 111, 114, 125, 102, 91, 97, 84, 139, 157,
104 109, 66, 72, 129, 111, 90, 127, 126, 101, 109,
105 142, 138, 129, 159, 140, 80, 74, 78, 76, 98,
106 68, 42, 106, 143, 112, 102, 115, 114, 82, 75,
107 92, 80, 110, 114, 66, 86, 119, 101, 101, 103,
108 118, 145, 85, 40, 62, 88, 95, 87, 73, 64,
109 86, 71, 71, 105, 80, 73, 96, 92, 85, 90,
110 81, 86, 105, 100, 89, 78, 102, 114, 95, 98,
111 69, 70, 108, 112, 111, 90, 104, 137, 143, 160,
112 145, 121, 98, 86, 91, 87, 115, 123, 109, 99,
113 85, 120, 131, 116, 125, 144, 153, 111, 98, 110,
114 93, 89, 101, 137, 155, 142, 108, 94, 136, 145,
115 129, 129, 122, 109, 90, 76, 81, 110, 119, 96,
116 95, 102, 105, 111, 90, 89, 111, 115, 86, 51,
117 107, 140, 105, 105, 110, 142, 125, 76, 75, 69,
118 65, 52, 61, 69, 55, 42, 47, 58, 37, 35,
119 24, 20, 44, 22, 16, 26, 6, 3, 4, 23,
120 60, 51, 30, 12, 24, 31, -9, -16, -13, 13,
121 19, 9, 37, 55, 70, 36, 23, 57, 45, 33,
122 50, 59, 18, 11, 62, 74, 52, 8, -3, 26,
123 51, 48, -5, -9, 12, -7, -12, -5, 28, 41,
124 -2, -30, -13, 31, 33, -12, -22, -8, -15, -17,
125 2, -6, -25, -27, -24, -8, 4, -9, -52, -47,
126 -9, -32, -45, -5, 41, 15, -32, -14, 2, -1,
127 -10, -30, -32, -25, -21, -17, -14, 8, -4, -13,
128 34, 18, -36, -38, -18, -19, -28, -17, -14, -16,
129 -2, -20, -27, 12, 11, -17, -33, -12, -22, -64,
130 -42, -26, -23, -22, -37, -51, -53, -30, -18, -48,
131 -69, -38, -54, -96, -72, -49, -50, -57, -41, -22,
132 -43, -64, -54, -23, -49, -69, -41, -44, -42, -49,
133 -40, -26, -54, -50, -38, -49, -70, -94, -89, -69,
134 -56, -65, -71, -47, -39, -49, -79, -91, -56, -46,
135 -62, -86, -64, -32, -47, -50, -71, -77, -65, -68,
136 -52, -51, -61, -67, -61, -81, -93, -52, -59, -62,
137 -51, -75, -76, -50, -32, -54, -68, -70, -43, 1,
138 -42, -92, -80, -41, -38, -79, -69, -49, -82, -122,
139 -93, -21, -24, -61, -70, -73, -62, -74, -69, -43,
140 -25, -15, -43, -23, -26, -69, -44, -12, 1, -51,
141 -78, -13, 3, -53, -105, -72, -24, -62, -66, -31,
142 -40, -65, -86, -64, -44, -55, -63, -61, -37, -41,
143 };
144
145 // Golden audio ops mfcc output for the above wav.
146 const std::vector<float> testWavMfcc
147 {
148 -22.67135, -0.61615, 2.07233, 0.58137, 1.01655, 0.85816, 0.46039, 0.03393, 1.16511, 0.0072,
149 };
150
151 std::vector<float> testWavFloat(640);
152 constexpr float normaliser = 1.0/(1u<<15u);
153 std::transform(testWav.begin(), testWav.end(), testWavFloat.begin(),
154 std::bind1st(std::multiplies<float>(), normaliser));
155
156 const float DsCNNInputQuantizationScale = 1.107164;
157 const int DsCNNInputQuantizationOffset = 95;
158
159 std::map<int,std::string> labels =
160 {
161 {0,"silence"},
162 {1, "unknown"},
163 { 2, "yes"},
164 { 3,"no"},
165 { 4, "up"},
166 { 5, "down"},
167 { 6, "left"},
168 { 7, "right"},
169 { 8, "on"},
170 { 9, "off"},
171 { 10, "stop"},
172 {11, "go"}
173 };
174 common::PipelineOptions options;
175 options.m_ModelFilePath = GetResourceFilePath("ds_cnn_clustered_int8.tflite");
176 options.m_ModelName = "DS_CNN_CLUSTERED_INT8";
177 options.m_backends = {"CpuAcc", "CpuRef"};
178 kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(options);
179
180 CHECK(kwsPipeline->getInputSamplesSize() == 16000);
181 std::vector<int8_t> expectedWavMfcc;
182 for(auto& i : testWavMfcc)
183 {
184 expectedWavMfcc.push_back(
185 (i + DsCNNInputQuantizationScale * DsCNNInputQuantizationOffset) / DsCNNInputQuantizationScale);
186 }
187
188 SECTION("Pre-processing")
189 {
190 testWavFloat.resize(16000);
191 expectedWavMfcc.resize(49 * 10);
192 std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(testWavFloat);
193 CHECK(preprocessedData.size() == expectedWavMfcc.size());
194 for(int i = 0; i < 10; ++i)
195 {
196 CHECK(expectedWavMfcc[i] == Approx(preprocessedData[i]).margin(1));
197 }
198 }
199
200 SECTION("Execute inference")
201 {
202 common::InferenceResults<int8_t> result;
203 std::vector<int8_t> IFM(std::begin(ifm0_kws), std::end(ifm0_kws));
204 kwsPipeline->Inference(IFM, result);
205 std::vector<int8_t> OFM(std::begin(ofm0_kws), std::end(ofm0_kws));
206
207 CHECK(1 == result.size());
208 CHECK(OFM.size() == result[0].size());
209
210 int count = 0;
211 for (auto& i : result)
212 {
213 for (signed char& j : i)
214 {
215 CHECK(j == OFM[count++]);
216
217 }
218 }
219 }
220
221 SECTION("Convert inference result to keyword")
222 {
223 std::vector< std::vector< int8_t >> modelOutput = {{1, 4, 2, 3, 1, 1, 3, 1, 43, 1, 6, 1}};
224 kwsPipeline->PostProcessing(modelOutput, labels,
225 [](int index, std::string& label, float prob) -> void {
226 CHECK(index == 8);
227 CHECK(label == "on");
228 });
229 }
230}