/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "InputFiles.hpp"
#include "Classifier.hpp"
#include "DsCnnModel.hpp"
#include "hal.h"
#include "DsCnnMfcc.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "KwsResult.hpp"

#include <cstring>      /* For std::memcpy, used when copying features into the tensor. */
#include <vector>
#include <functional>

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    /**
     * @brief Helper function to increment current audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     **/
    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     * @param[in]     idx Value to be set.
     * @return true if index is set, false otherwise.
     **/
    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results  Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool _PresentInferenceResult(hal_platform& platform,
                                        const std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data
     *        with MFCC data.
     *
     * An input tensor data type check is performed to choose the correct MFCC feature data type.
     * If the tensor has an integer data type then the original features are quantised.
     *
     * Warning: the MFCC calculator provided as input must outlive the returned function.
     *
     * @param[in]     mfcc        MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in]     cacheSize   Size of the feature vectors cache (number of feature vectors).
     * @return Function to be called providing audio sample and sliding window index.
     */
    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
           GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                                TfLiteTensor* inputTensor,
                                size_t cacheSize);

    /* Audio inference handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx) ?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);
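
        /* minTensorDims is the larger of the input row/column dimension indices;
         * the tensor rank check below guards the inputShape->data[...] accesses
         * made further down. */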

        platform.data_psn->clear(COLOR_BLACK);

        auto& model = ctx.Get<Model&>("model");

        /* If the request has a valid index, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!_SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        const auto frameLength = ctx.Get<int>("frameLength");
        const auto frameStride = ctx.Get<int>("frameStride");
        const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
        TfLiteTensor* inputTensor = model.GetInputTensor(0);

        if (!inputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return false;
        } else if (inputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return false;
        }

        TfLiteIntArray* inputShape = model.GetInputShape(0);
        const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx];
        const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx];

        audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength);
        mfcc.Init();

        /* Deduce the data length required for 1 inference from the network parameters. */
        auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride);
        auto mfccWindowSize = frameLength;
        auto mfccWindowStride = frameStride;

        /* We choose to move by half the window size => for a 1 second window size
         * there is an overlap of 0.5 seconds. */
        auto audioDataStride = audioDataWindowSize / 2;

        /* For the previously calculated features to be re-usable, the stride must
         * be a multiple of the MFCC features window stride. */
        if (0 != audioDataStride % mfccWindowStride) {
            /* Reduce the stride. */
            audioDataStride -= audioDataStride % mfccWindowStride;
        }

        auto nMfccVectorsInAudioStride = audioDataStride / mfccWindowStride;
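
        /* Worked example (assuming the default DS-CNN KWS configuration of a
         * 16 kHz model with frameLength = 640, frameStride = 320 and a
         * 49 x 10 input shape; these values are set outside this file):
         *   audioDataWindowSize       = 49 * 320 + (640 - 320) = 16000 samples (1 s)
         *   audioDataStride           = 16000 / 2 = 8000 samples (already a
         *                               multiple of 320, so no reduction needed)
         *   nMfccVectorsInAudioStride = 8000 / 320 = 25 feature vectors. */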

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float secondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        do {
            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Create an MFCC features sliding window for the data required for 1 inference. */
            auto audioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    audioDataWindowSize, mfccWindowSize,
                    mfccWindowStride);

            /* Create a sliding window through the whole audio clip. */
            auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    get_audio_array_size(currentIndex),
                    audioDataWindowSize, audioDataStride);

            /* Calculate the number of feature vectors in the window overlap region.
             * These feature vectors will be reused. */
            auto numberOfReusedFeatureVectors = audioMFCCWindowSlider.TotalStrides() + 1
                                                - nMfccVectorsInAudioStride;
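
            /* Continuing the example above: the MFCC slider makes
             * (16000 - 640) / 320 = 48 strides, i.e. 49 feature vectors per
             * window, so numberOfReusedFeatureVectors = 49 - 25 = 24: the last
             * 24 vectors of one window are the first 24 of the next. */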

            /* Construct the feature calculation function. */
            auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor,
                                                        numberOfReusedFeatureVectors);

            if (!mfccFeatureCalc) {
                return false;
            }

            /* Declare a container for results. */
            std::vector<arm::app::kws::KwsResult> results;

            /* Display message on the LCD - inference running. */
            std::string str_inf{"Running inference... "};
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
            info("Running inference on audio clip %u => %s\n", currentIndex,
                 get_filename(currentIndex));

            /* Start sliding through the audio clip. */
            while (audioDataSlider.HasNext()) {
                const int16_t* inferenceWindow = audioDataSlider.Next();

                /* We moved to the next window - set the features sliding window to the new address. */
                audioMFCCWindowSlider.Reset(inferenceWindow);

                /* The first window does not have the cache ready. */
                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

                /* Start calculating features inside one audio sliding window. */
                while (audioMFCCWindowSlider.HasNext()) {
                    const int16_t* mfccWindow = audioMFCCWindowSlider.Next();
                    std::vector<int16_t> mfccAudioData = std::vector<int16_t>(mfccWindow,
                                                            mfccWindow + mfccWindowSize);
                    /* Compute features for this window and write them to the input tensor. */
                    mfccFeatureCalc(mfccAudioData,
                                    audioMFCCWindowSlider.Index(),
                                    useCache,
                                    nMfccVectorsInAudioStride);
                }

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                     audioDataSlider.TotalStrides() + 1);

                /* Run inference over this audio clip sliding window. */
                arm::app::RunInference(platform, model);

                std::vector<ClassificationResult> classificationResult;
                auto& classifier = ctx.Get<KwsClassifier&>("classifier");
                classifier.GetClassificationResults(outputTensor, classificationResult,
                                                    ctx.Get<std::vector<std::string>&>("labels"), 1);

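                /* The result timestamp marks the start of this inference window:
                 * Index() * audioDataStride * secondsPerSample. With the example
                 * configuration above, each inference advances the timestamp by
                 * 8000 / 16000 = 0.5 s. */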
                results.emplace_back(kws::KwsResult(classificationResult,
                    audioDataSlider.Index() * secondsPerSample * audioDataStride,
                    audioDataSlider.Index(), scoreThreshold));

#if VERIFY_TEST_OUTPUT
                arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
            } /* while (audioDataSlider.HasNext()) */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

            ctx.Set<std::vector<arm::app::kws::KwsResult>>("results", results);

            if (!_PresentInferenceResult(platform, results)) {
                return false;
            }

            _IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }
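
    /* A minimal sketch of the context set-up this handler expects. This is
     * normally done by the use case's main loop before invoking the handler;
     * the g_* parameter names here are assumptions, not definitions from this
     * file:
     *
     *   ApplicationContext caseContext;
     *   caseContext.Set<hal_platform&>("platform", platform);
     *   caseContext.Set<Model&>("model", model);
     *   caseContext.Set<uint32_t>("clipIndex", 0);
     *   caseContext.Set<int>("frameLength", g_FrameLength);
     *   caseContext.Set<int>("frameStride", g_FrameStride);
     *   caseContext.Set<float>("scoreThreshold", g_ScoreThreshold);
     *   caseContext.Set<KwsClassifier&>("classifier", classifier);
     *   caseContext.Set<std::vector<std::string>&>("labels", labels);
     *
     *   ClassifyAudioHandler(caseContext, 0, true);  // Loop over all clips.
     */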

    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

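        /* Wrap around to the first clip: this is what terminates the runAll
         * do-while loop in ClassifyAudioHandler once every clip has been processed. */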
        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %u (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool _PresentInferenceResult(hal_platform& platform,
                                        const std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (results[i].m_resultVec.size()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                resultStr.c_str(), resultStr.size(),
                dataPsnTxtStartX1, rowIdx1, false);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %u); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %u: %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function that computes features using a features cache.
     * The actual feature computation is done by the lambda function provided
     * as a parameter. Features are written to the input tensor memory.
     *
     * @tparam T          Feature vector type.
     * @param inputTensor Model input tensor pointer.
     * @param cacheSize   Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute     Features calculator function.
     * @return Lambda function to compute features.
     */
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                 std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by the lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
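
        /* Note: the cache is `static`, so there is one per template
         * instantiation (per T) and it is sized on the first call only;
         * later calls with a different cacheSize reuse the existing cache. */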

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
             * The overlap is at the beginning of the sliding window, with the size of the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration goes out of the windows overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }
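
    /* Cache flow, continuing the worked example from ClassifyAudioHandler
     * (cacheSize = 24, featuresOverlapIndex = 25, 49 feature vectors per window):
     *   - window N: indices 25..48 also store their features into cache
     *     slots 0..23 after copying them into the tensor;
     *   - window N+1: indices 0..23 are served from those slots, so only
     *     indices 24..48 need a fresh MFCC computation. */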

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<float>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<float> (std::vector<int16_t>&)> compute);

    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = _FeatureCalc<int8_t>(inputTensor,
                                            cacheSize,
                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                     quantScale,
                                                                                     quantOffset);
                                            });
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = _FeatureCalc<uint8_t>(inputTensor,
                                            cacheSize,
                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                      quantScale,
                                                                                      quantOffset);
                                            });
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = _FeatureCalc<int16_t>(inputTensor,
                                            cacheSize,
                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                      quantScale,
                                                                                      quantOffset);
                                            });
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }
        } else {
            mfccFeatureCalc = _FeatureCalc<float>(inputTensor,
                                    cacheSize,
                                    [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                        return mfcc.MfccCompute(audioDataWindow);
                                    });
        }
        return mfccFeatureCalc;
    }
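
    /* For reference, MfccComputeQuant is expected to quantise each MFCC
     * coefficient roughly as
     *     q = clamp(round(value / quantScale) + quantOffset, typeMin, typeMax)
     * (a sketch of the usual affine quantisation scheme; the exact rounding
     * and clamping live in the MFCC implementation, not in this file). */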

} /* namespace app */
} /* namespace arm */