/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DsCnnModel.hpp"
#include "DsCnnMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };

    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief           Helper function to increment the current audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief           Helper function to set the audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     * @param[in]       idx   Value to be set.
     * @return          true if the index is set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief           Presents KWS inference results using the data presentation
     *                  object.
     * @param[in]       platform   Reference to the hal platform object.
     * @param[in]       results    Vector of classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief           Presents ASR inference results using the data presentation
     *                  object.
     * @param[in]       platform   Reference to the hal platform object.
     * @param[in]       results    Vector of classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose the correct MFCC feature data type.
     * If the tensor has an integer data type then the original features are quantised.
     *
     * Warning: the MFCC calculator provided as input must outlive the returned function.
     *
     * @param[in]       mfcc          MFCC feature calculator.
     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
     *
     * @return          Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
        GetFeatureCalculator(audio::DsCnnMFCC&  mfcc,
                             TfLiteTensor*      inputTensor,
                             size_t             cacheSize);

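    /* A minimal usage sketch for the returned callable (the names below are
     * illustrative, not part of the API above): for every MFCC window w
     * within an inference window, call
     *
     *     featureCalc(mfccWindowSamples, w, useCache, vectorsPerStride);
     *
     * with useCache set only after the first inference window, so feature
     * vectors in the window overlap are read back from the cache instead of
     * being recomputed. doKws() below follows this pattern. */
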
    /**
     * @brief           Performs the KWS pipeline.
     * @param[in,out]   ctx   Reference to the application context object.
     *
     * @return          KWSOutput struct containing a pointer to the audio data
     *                  where ASR should begin and how much data to process.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                      (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %u\n", kwsAudioDataWindowSize);

        /* Stride must be a multiple of the MFCC window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride / kwsMfccWindowStride;

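        /* Worked example, assuming the default DS-CNN KWS settings shipped
         * with this use case (16 kHz audio, frame length 640 samples, frame
         * stride 320, 49 audio windows): the window size evaluates to
         * 49 * 320 + (640 - 320) = 16000 samples (1 second), the stride to
         * 16000 / 2 = 8000 samples (already a multiple of 320), and
         * kwsMfccVectorsInAudioStride to 8000 / 320 = 25 feature vectors. */
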
        /* We expect to be sampling 1 second's worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating an MFCC features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                                            get_audio_array(currentIndex),
                                            kwsAudioDataWindowSize, kwsMfccWindowSize,
                                            kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                                    get_audio_array(currentIndex),
                                    get_audio_array_size(currentIndex),
                                    kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate the number of feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc) {
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %u => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding window to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have the cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to the input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(kwsModel, profiler);

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            kwsClassifier.GetClassificationResults(
                            kwsOutputTensor, kwsClassificationResult,
                            ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                         kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }
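            /* Note on the handoff arithmetic above: asrAudioStart points one
             * past the end of the KWS window that triggered the detection
             * (inferenceWindow + kwsAudioDataWindowSize), and asrAudioSamples
             * is whatever remains of the clip from that same offset, since
             * NextWindowStartIndex() - kwsAudioDataStride is the start index
             * of the current window. */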

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief           Performs the ASR pipeline.
     *
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing the pointer to the audio data
     *                              where ASR should begin and how much data to process.
     * @return          true if the pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the caller. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %u\n", asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                                              (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
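
        /* Worked example, assuming the Wav2Letter defaults used elsewhere in
         * this application (16 kHz audio, MFCC frame length 512, frame stride
         * 160, 296 input rows, context length 98): the inference window is
         * (296 - 1) * 160 + 512 = 47712 samples (~2.98 s) and the stride is
         * (296 - 2 * 98) * 160 = 16000 samples (1 s). */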

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %u\n", asrMfccParamsWinLen);
            return false;
        }
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
                                   audioBuffer.data(),
                                   audioBuffer.size(),
                                   asrAudioParamsWinLen,
                                   asrAudioParamsWinStride);
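
        /* The ASR-specific sliding window tolerates a partial final stride
         * (see FractionalTotalStrides() below), so the tail of the buffer is
         * still presented for processing, with a shortened window length. */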

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If there is not enough audio data left for a full window, use
             * however much can still be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            arm::app::RunInference(asrModel, profiler);

            /* Post-process. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
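
            /* The final argument flags the last window; presumably this lets
             * the post-process retain the trailing context section of the
             * final window rather than trimming it as it does for
             * intermediate windows. */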

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioParamsWinStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* If the request has a valid index, set the audio clip index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

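        /* IncrementAppCtxClipIdx() wraps clipIndex back to 0 after the last
         * file, so with runAll set this loop visits every clip exactly once,
         * stopping when the index returns to where it started. */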
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %u (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16;  /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!results[i].m_resultVec.empty()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, false);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %u); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %u: %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %u: %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function to compute features using a features cache.
     * The actual feature computation is done by the function object passed in
     * as a parameter. Features are written to the input tensor memory.
     *
     * @tparam T            Feature vector type.
     * @param inputTensor   Model input tensor pointer.
     * @param cacheSize     Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute       Feature calculator function.
     * @return              Lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by the lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
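        /* Note: as a function-local static, this cache is shared by every
         * callable a given FeatureCalc<T> instantiation ever returns, so the
         * factory is only suitable for one model/input stream at a time. */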

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
             * The overlap is at the beginning of the sliding window and has the size of the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration goes beyond the window overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float> (std::vector<int16_t>&)> compute);


    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];
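
            /* With affine quantisation, a real value maps to the integer
             * domain as q = round(value / quantScale) + quantOffset; the
             * MfccComputeQuant<T>() calls below are expected to apply this
             * mapping when writing features to the tensor. */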

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
                                          cacheSize,
                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                   quantScale,
                                                                                   quantOffset);
                                          }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
                                          cacheSize,
                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                              return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                    quantScale,
                                                                                    quantOffset);
                                          }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
                                          cacheSize,
                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                              return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                    quantScale,
                                                                                    quantOffset);
                                          }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }

        } else {
            mfccFeatureCalc = FeatureCalc<float>(inputTensor,
                                  cacheSize,
                                  [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                      return mfcc.MfccCompute(audioDataWindow);
                                  });
        }
        return mfccFeatureCalc;
    }

} /* namespace app */
} /* namespace arm */