/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DsCnnModel.hpp"
#include "DsCnnMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };

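    /**
     * @brief Output of the KWS stage: indicates whether KWS executed
     *        successfully and, if a keyword was spotted, where in the
     *        audio clip ASR should begin and how many samples to process.
     **/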
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief Helper function to increment current audio clip index
     * @param[in,out] ctx   reference to the application context object
     **/
    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index
     * @param[in,out] ctx   reference to the application context object
     * @param[in] idx       value to be set
     * @return true if index is set, false otherwise
     **/
    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents KWS inference results using the data presentation
     *        object.
     * @param[in] platform  reference to the hal platform object
     * @param[in] results   vector of classification results to be displayed
     * @return true if successful, false otherwise
     **/
    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Presents ASR inference results using the data presentation
     *        object.
     * @param[in] platform  reference to the hal platform object
     * @param[in] results   vector of classification results to be displayed
     * @return true if successful, false otherwise
     **/
    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     *        MFCC data.
     *
     * Input tensor data type check is performed to choose the correct MFCC feature data type.
     * If the tensor has an integer data type, then the original features are quantised.
     *
     * Warning: the MFCC calculator provided as input must have the same life scope as the returned function.
     *
     * @param[in]     mfcc          MFCC feature calculator.
     * @param[in,out] inputTensor   Input tensor pointer to store calculated features.
     * @param[in]     cacheSize     Size of the feature vectors cache (number of feature vectors).
     *
     * @return Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize);

    /**
     * @brief Performs the KWS pipeline.
     * @param[in,out] ctx   reference to the application context object
     *
     * @return KWSOutput struct containing a pointer to the audio data where ASR should begin
     *         and how much data to process.
     **/
    static KWSOutput doKws(ApplicationContext& ctx) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx) ?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        KWSOutput output;

        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                      (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;
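
        /* Illustrative arithmetic, assuming the default DS-CNN KWS
         * parameters (49 MFCC windows, frame length 640 and frame stride
         * 320 at 16 kHz): 49 * 320 + (640 - 320) = 16000 samples, i.e.
         * one second of audio per inference. */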

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %u\n", kwsAudioDataWindowSize);

        /* Stride must be a multiple of the MFCC features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }
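
        /* With the defaults assumed above: stride = 16000 / 2 = 8000 and
         * 8000 % 320 == 0, so no adjustment is needed. In general this
         * rounds the stride down to the nearest multiple of the MFCC
         * window stride so that cached feature vectors stay aligned. */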

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride / kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating an MFCC features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate the number of feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;
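
        /* Cache sizing sketch, under the same assumed defaults: each
         * inference spans TotalStrides() + 1 = (16000 - 640) / 320 + 1 = 49
         * MFCC vectors and the audio stride covers 8000 / 320 = 25 of
         * them, so 49 - 25 = 24 vectors lie in the overlap and can be reused. */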

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc) {
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        info("Running KWS inference on audio clip %u => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through the audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding window to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have the cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to the input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(platform, kwsModel);

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            kwsClassifier.GetClassificationResults(
                    kwsOutputTensor, kwsClassificationResult,
                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                    kws::KwsResult(
                            kwsClassificationResult,
                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                            audioDataSlider.Index(), kwsScoreThreshold)
            );

            /* Keyword detected. */
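            /* ASR should pick up right after the window in which the
             * keyword was spotted: asrAudioStart points one full KWS
             * window past the current window start (i.e. just past the
             * end of the detection window) and asrAudioSamples is
             * whatever remains of the clip beyond that point. */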
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                         (audioDataSlider.NextWindowStartIndex() -
                                          kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        if (!_PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief Performs the ASR pipeline.
     *
     * @param[in,out] ctx         reference to the application context object
     * @param[in]     kwsOutput   struct containing a pointer to the audio data where ASR should begin
     *                            and how much data to process
     * @return true if the pipeline executed without failure
     **/
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the caller. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %u\n", asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
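
        /* Illustrative numbers, assuming the default Wav2Letter setup
         * (296 input rows, ctxLen 98, MFCC frame length 512 and stride
         * 160 at 16 kHz): innerLen = 296 - 2 * 98 = 100 rows, so each
         * window covers (296 - 1) * 160 + 512 = 47712 samples (~2.98 s)
         * and the slider advances by 100 * 160 = 16000 samples (1 s). */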

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %u\n", asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through the audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            Profiler prepProfiler{&platform, "pre-processing"};
            prepProfiler.StartProfiling();

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            prepProfiler.StopProfiling();
            std::string prepProfileResults = prepProfiler.GetResultsAndReset();
            info("%s\n", prepProfileResults.c_str());

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(platform, asrModel);

            /* Post-process. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioParamsWinStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor,
                asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }

        if (!_PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        return true;
    }

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* If the request has a valid index, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!_SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            _IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %u (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (results[i].m_resultVec.size()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %u); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %u: %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %u: %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function that computes features, backed by a features cache.
     * The actual feature computation is performed by the lambda function provided as a parameter.
     * Features are written to the input tensor memory.
     *
     * @tparam T            feature vector type.
     * @param inputTensor   model input tensor pointer.
     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute       features calculator function.
     * @return Lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                 std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by the lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
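
        /* Note: the cache is a function-local static, so it is constructed
         * once per template instantiation on the first call; a different
         * cacheSize passed on a later call with the same T would not
         * resize it. That is acceptable here, where each type is used
         * with a single model configuration. */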

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
             * The overlap is at the beginning of the sliding window, with a size of the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration goes beyond the window overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }
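
    /* Usage sketch (illustrative only; `overlap`, `window`, `idx` and
     * `notFirstInference` are placeholder names): build a float feature
     * calculator caching `overlap` vectors, then invoke it once per MFCC
     * window index within an inference.
     *
     *   auto calc = _FeatureCalc<float>(inputTensor, overlap,
     *       [&mfcc](std::vector<int16_t>& w) { return mfcc.MfccCompute(w); });
     *   calc(window, idx, notFirstInference, overlap);
     */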

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<float>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<float> (std::vector<int16_t>&)> compute);

    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];
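
            /* TFLite affine quantization maps a real value to an integer
             * as q = real / scale + zero_point (rounded); the
             * MfccComputeQuant calls below are expected to apply this
             * mapping to the float features before they are written to
             * the input tensor. */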

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = _FeatureCalc<int8_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                                    quantScale,
                                                                                                    quantOffset);
                                                           }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = _FeatureCalc<uint8_t>(inputTensor,
                                                            cacheSize,
                                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                                return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                                      quantScale,
                                                                                                      quantOffset);
                                                            }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = _FeatureCalc<int16_t>(inputTensor,
                                                            cacheSize,
                                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                                return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                                      quantScale,
                                                                                                      quantOffset);
                                                            }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }

        } else {
            mfccFeatureCalc = _FeatureCalc<float>(inputTensor,
                                                  cacheSize,
                                                  [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                                      return mfcc.MfccCompute(audioDataWindow);
                                                  });
        }
        return mfccFeatureCalc;
    }

} /* namespace app */
} /* namespace arm */