/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DsCnnModel.hpp"
#include "DsCnnMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"


using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };

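    /** Output of the KWS stage: indicates whether KWS ran successfully and, if a
     *  keyword was spotted, where in the audio clip ASR should begin and how many
     *  samples it has to process. */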
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief           Helper function to increment current audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     **/
    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief           Helper function to set the audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     * @param[in]       idx   Value to be set.
     * @return          true if index is set, false otherwise.
     **/
    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief       Presents KWS inference results using the data presentation
     *              object.
     * @param[in]   platform   Reference to the hal platform object.
     * @param[in]   results    Vector of classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief       Presents ASR inference results using the data presentation
     *              object.
     * @param[in]   platform   Reference to the hal platform object.
     * @param[in]   results    Vector of classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function that performs feature calculation and populates the
     *        input tensor with MFCC data.
     *
     * The input tensor's data type is checked to choose the correct MFCC feature
     * data type. If the tensor has an integer data type, the features are quantised.
     *
     * Warning: the MFCC calculator provided as input must outlive the returned function.
     *
     * @param[in]       mfcc          MFCC feature calculator.
     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
     *
     * @return          Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
            GetFeatureCalculator(audio::DsCnnMFCC&  mfcc,
                                 TfLiteTensor*      inputTensor,
                                 size_t             cacheSize);

    /**
     * @brief           Performs the KWS pipeline.
     * @param[in,out]   ctx   Reference to the application context object.
     * @return          KWSOutput struct containing a pointer to the audio data where ASR
     *                  should begin and how much data to process.
     **/
    static KWSOutput doKws(ApplicationContext& ctx) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                        (kwsFrameLength - kwsFrameStride);
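        /* For example, with the default DS-CNN KWS parameters (49 windows, frame
         * length 640, frame stride 320) this works out to 49 * 320 + (640 - 320)
         * = 16000 samples, i.e. 1 second of audio at 16 kHz. */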
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %u\n", kwsAudioDataWindowSize);

        /* The stride must be a multiple of the MFCC window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride / kwsMfccWindowStride;
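        /* With the example figures above, the stride of 16000 / 2 = 8000 samples
         * is already a multiple of 320, giving 8000 / 320 = 25 MFCC vectors per
         * audio stride. */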

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating an MFCC features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate the number of feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;
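        /* Continuing the example: 49 feature vectors per inference window minus 25
         * new vectors per stride leaves 24 vectors that can be reused from the
         * cache on every window after the first. */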

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc) {
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        info("Running KWS inference on audio clip %u => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(kwsModel, profiler);

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            kwsClassifier.GetClassificationResults(
                    kwsOutputTensor, kwsClassificationResult,
                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                    kws::KwsResult(
                        kwsClassificationResult,
                        audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                        audioDataSlider.Index(), kwsScoreThreshold)
                    );

            /* Keyword detected. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
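                /* ASR begins right after the KWS window in which the keyword was
                 * spotted; the sample count is whatever remains of the clip beyond
                 * the end of that window. */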
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                         kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        if (!_PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief           Performs the ASR pipeline.
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing a pointer to the audio data where ASR
     *                              should begin and how much data to process.
     * @return          true if the pipeline executed without failure.
     **/
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the caller. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %u\n", asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                                        (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
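        /* As an illustration, with the default Wav2Letter parameters (296 input
         * rows, context length of 98 rows, MFCC frame length 512 and stride 160)
         * the inner length is 296 - 2 * 98 = 100 rows, the audio window is
         * 295 * 160 + 512 = 47712 samples, and the stride is 100 * 160 = 16000
         * samples (1 second at 16 kHz). */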

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %u\n", asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If there is not enough audio left, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(asrModel, profiler);

            /* Post-process. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
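            /* The last argument flags the final window; Wav2LetterPostprocess
             * uses it when trimming the contextual sections of the output. */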

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                    (audioDataSlider.Index() *
                                     asrAudioParamsSecondsPerSample *
                                     asrAudioParamsWinStride),
                                    audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor,
                                 asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!_PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* If the request has a valid size, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!_SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        do {
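            /* Each iteration runs KWS on the current clip and, if the keyword is
             * spotted with audio samples to spare, runs ASR on the remainder. With
             * runAll set, we loop until the clip index wraps back to where we started. */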
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            _IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static void _IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %u (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16;  /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (results[i].m_resultVec.size()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %u); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %u: %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool _PresentInferenceResult(hal_platform& platform,
                                        std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %u: %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function that computes features, backed by a feature cache.
     * The actual feature computation is performed by the function passed in as the
     * `compute` parameter. Features are written to the input tensor memory.
     *
     * @tparam T            Feature vector type.
     * @param inputTensor   Model input tensor pointer.
     * @param cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute       Features calculator function.
     * @return              Lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                 std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
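        /* Note: as a function-local static, the cache is constructed once per
         * template instantiation (one per T) with the cacheSize of the first call;
         * this suffices here because only one feature calculator of a given type
         * is active at a time. */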

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from cache if cache is ready and sliding windows overlap.
             * Overlap is in the beginning of sliding window with a size of a feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
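            /* The input tensor is treated as a [numWindows x featureSize] matrix:
             * the feature vector for sliding window `index` lands in row `index`. */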
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration moves past the window overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                          size_t cacheSize,
                          std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    _FeatureCalc<float>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<float> (std::vector<int16_t>&)> compute);

    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {
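            /* Quantised input tensor: features are computed in floating point and
             * then mapped into the tensor's integer domain using its scale and
             * zero point. */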

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = _FeatureCalc<int8_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                                    quantScale,
                                                                                                    quantOffset);
                                                           }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = _FeatureCalc<uint8_t>(inputTensor,
                                                            cacheSize,
                                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                                return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                                      quantScale,
                                                                                                      quantOffset);
                                                            }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = _FeatureCalc<int16_t>(inputTensor,
                                                            cacheSize,
                                                            [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                                return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                                      quantScale,
                                                                                                      quantOffset);
                                                            }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }

        } else {
            mfccFeatureCalc = _FeatureCalc<float>(inputTensor,
                                                  cacheSize,
                                                  [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                                      return mfcc.MfccCompute(audioDataWindow);
                                                  });
        }
        return mfccFeatureCalc;
    }

} /* namespace app */
} /* namespace arm */