blob: 0560e888dce0f3a6dda60d68ef90052cd6907332 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "hal.h"
20#include "InputFiles.hpp"
21#include "AudioUtils.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "DsCnnModel.hpp"
24#include "DsCnnMfcc.hpp"
25#include "Classifier.hpp"
26#include "KwsResult.hpp"
27#include "Wav2LetterMfcc.hpp"
28#include "Wav2LetterPreprocess.hpp"
29#include "Wav2LetterPostprocess.hpp"
30#include "AsrResult.hpp"
31#include "AsrClassifier.hpp"
32#include "OutputDecode.hpp"
33
34
35using KwsClassifier = arm::app::Classifier;
36
37namespace arm {
38namespace app {
39
    /* Axis indices that an ASR output reduction could be performed over.
     * NOTE(review): not referenced in this translation unit — the reduction axis
     * actually used below comes from Wav2LetterModel::ms_outputRowsIdx. Confirm
     * whether this enum is still needed. */
    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };
44
    /* Aggregated result of the KWS stage, handed over to the ASR stage. */
    struct KWSOutput {
        bool executionSuccess = false;          /* True if the KWS pipeline ran to completion. */
        const int16_t* asrAudioStart = nullptr; /* Start of the audio region ASR should process. */
        int32_t asrAudioSamples = 0;            /* Number of samples available from asrAudioStart. */
    };
50
    /**
     * @brief Helper function to increment current audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     * @param[in] idx Value to be set.
     * @return true if index is set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents KWS inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results Vector of KWS classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Presents ASR inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results Vector of ASR classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose correct MFCC feature data type.
     * If tensor has an integer data type then original features are quantised.
     *
     * Warning: MFCC calculator provided as input must have the same life scope as returned function.
     *
     * @param[in] mfcc MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
     *
     * @return Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
            GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                                 TfLiteTensor* inputTensor,
                                 size_t cacheSize);
106
    /**
     * @brief Performs the KWS pipeline: slides over the current audio clip,
     *        computes MFCC features (reusing cached vectors in window overlaps),
     *        runs the KWS model on each window and stops early when the trigger
     *        keyword is detected.
     * @param[in,out] ctx Reference to the application context object.
     *
     * @return KWSOutput struct containing pointer to audio data where ASR should begin
     *         and how much data to process.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        /* Top-left coordinates of the "inference running" banner on the LCD. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        /* The input tensor must expose at least the larger of the row/col axis indices. */
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
            arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        /* Default-constructed => executionSuccess == false, no ASR hand-over. */
        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* KWS frame/scoring parameters injected by the main application. */
        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);

        /* Stride must be multiple of mfcc features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating a mfcc features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused.*/
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                - kwsMfccVectorsInAudioStride;

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        /* Empty std::function => unsupported tensor type; bail out. */
        if (!kwsMfccFeatureCalc){
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            /* Take only the top-1 classification for this window. */
            kwsClassifier.GetClassificationResults(
                    kwsOutputTensor, kwsClassificationResult,
                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                /* Hand over to ASR the audio that follows the current window:
                 * start just past this window, length = what remains of the clip. */
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                        (audioDataSlider.NextWindowStartIndex() -
                        kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }
290
    /**
     * @brief Performs the ASR pipeline on the audio region selected by KWS:
     *        slides over the remaining samples, runs preprocessing, inference
     *        and post-processing per window, then decodes and displays the
     *        combined transcript.
     *
     * @param[in,out] ctx       Reference to the application context object.
     * @param[in]     kwsOutput Struct containing pointer to audio data where ASR should begin
     *                          and how much data to process.
     * @return true if pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        /* Top-left coordinates of the "inference running" banner on the LCD. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the callee. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths.
         * (Also guards the unsigned subtraction above against wrap-around.) */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process: final-window flag tells the post-processor this is the tail. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                  (audioDataSlider.Index() *
                                                   asrAudioParamsSecondsPerSample *
                                                   asrAudioParamsWinStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }
438
439 /* Audio inference classification handler. */
440 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
441 {
442 auto& platform = ctx.Get<hal_platform&>("platform");
443 platform.data_psn->clear(COLOR_BLACK);
444
445 /* If the request has a valid size, set the audio index. */
446 if (clipIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +0100447 if (!SetAppCtxClipIdx(ctx, clipIndex)) {
alexander3c798932021-03-26 21:42:19 +0000448 return false;
449 }
450 }
451
452 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
453
454 do {
455 KWSOutput kwsOutput = doKws(ctx);
456 if (!kwsOutput.executionSuccess) {
457 return false;
458 }
459
460 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
461 info("Keyword spotted\n");
462 if(!doAsr(ctx, kwsOutput)) {
463 printf_err("ASR failed");
464 return false;
465 }
466 }
467
alexanderc350cdc2021-04-29 20:36:09 +0100468 IncrementAppCtxClipIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000469
470 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
471
472 return true;
473 }
474
alexanderc350cdc2021-04-29 20:36:09 +0100475 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000476 {
477 auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");
478
479 if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
480 ctx.Set<uint32_t>("clipIndex", 0);
481 return;
482 }
483 ++curAudioIdx;
484 ctx.Set<uint32_t>("clipIndex", curAudioIdx);
485 }
486
alexanderc350cdc2021-04-29 20:36:09 +0100487 static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000488 {
489 if (idx >= NUMBER_OF_FILES) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100490 printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
alexander3c798932021-03-26 21:42:19 +0000491 idx, NUMBER_OF_FILES);
492 return false;
493 }
494 ctx.Set<uint32_t>("clipIndex", idx);
495 return true;
496 }
497
alexanderc350cdc2021-04-29 20:36:09 +0100498 static bool PresentInferenceResult(hal_platform& platform,
499 std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000500 {
501 constexpr uint32_t dataPsnTxtStartX1 = 20;
502 constexpr uint32_t dataPsnTxtStartY1 = 30;
503 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
504
505 platform.data_psn->set_text_color(COLOR_GREEN);
506
507 /* Display each result. */
508 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
509
510 for (uint32_t i = 0; i < results.size(); ++i) {
511
512 std::string topKeyword{"<none>"};
513 float score = 0.f;
514
alexanderc350cdc2021-04-29 20:36:09 +0100515 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000516 topKeyword = results[i].m_resultVec[0].m_label;
517 score = results[i].m_resultVec[0].m_normalisedVal;
518 }
519
520 std::string resultStr =
521 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
522 std::string{"s: "} + topKeyword + std::string{" ("} +
523 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
524
525 platform.data_psn->present_data_text(
526 resultStr.c_str(), resultStr.size(),
527 dataPsnTxtStartX1, rowIdx1, 0);
528 rowIdx1 += dataPsnTxtYIncr;
529
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100530 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
alexander3c798932021-03-26 21:42:19 +0000531 results[i].m_timeStamp, results[i].m_inferenceNumber,
532 results[i].m_threshold);
533 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100534 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
alexander3c798932021-03-26 21:42:19 +0000535 results[i].m_resultVec[j].m_label.c_str(),
536 results[i].m_resultVec[j].m_normalisedVal);
537 }
538 }
539
540 return true;
541 }
542
alexanderc350cdc2021-04-29 20:36:09 +0100543 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000544 {
545 constexpr uint32_t dataPsnTxtStartX1 = 20;
546 constexpr uint32_t dataPsnTxtStartY1 = 80;
547 constexpr bool allow_multiple_lines = true;
548
549 platform.data_psn->set_text_color(COLOR_GREEN);
550
551 /* Results from multiple inferences should be combined before processing. */
552 std::vector<arm::app::ClassificationResult> combinedResults;
553 for (auto& result : results) {
554 combinedResults.insert(combinedResults.end(),
555 result.m_resultVec.begin(),
556 result.m_resultVec.end());
557 }
558
559 for (auto& result : results) {
560 /* Get the final result string using the decoder. */
561 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
562
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100563 info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
alexander3c798932021-03-26 21:42:19 +0000564 infResultStr.c_str());
565 }
566
567 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
568
569 platform.data_psn->present_data_text(
570 finalResultStr.c_str(), finalResultStr.size(),
571 dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
572
573 info("Final result: %s\n", finalResultStr.c_str());
574 return true;
575 }
576
577 /**
578 * @brief Generic feature calculator factory.
579 *
580 * Returns lambda function to compute features using features cache.
581 * Real features math is done by a lambda function provided as a parameter.
582 * Features are written to input tensor memory.
583 *
584 * @tparam T feature vector type.
585 * @param inputTensor model input tensor pointer.
586 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
587 * @param compute features calculator function.
588 * @return lambda function to compute features.
589 **/
590 template<class T>
591 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100592 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
593 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
alexander3c798932021-03-26 21:42:19 +0000594 {
595 /* Feature cache to be captured by lambda function. */
596 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
597
598 return [=](std::vector<int16_t>& audioDataWindow,
599 size_t index,
600 bool useCache,
601 size_t featuresOverlapIndex)
602 {
603 T* tensorData = tflite::GetTensorData<T>(inputTensor);
604 std::vector<T> features;
605
606 /* Reuse features from cache if cache is ready and sliding windows overlap.
607 * Overlap is in the beginning of sliding window with a size of a feature cache.
608 */
609 if (useCache && index < featureCache.size()) {
610 features = std::move(featureCache[index]);
611 } else {
612 features = std::move(compute(audioDataWindow));
613 }
614 auto size = features.size();
615 auto sizeBytes = sizeof(T) * size;
616 std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
617
618 /* Start renewing cache as soon iteration goes out of the windows overlap. */
619 if (index >= featuresOverlapIndex) {
620 featureCache[index - featuresOverlapIndex] = std::move(features);
621 }
622 };
623 }
624
    /* Explicit instantiations of FeatureCalc for every tensor data type that
     * GetFeatureCalculator below can select. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000644
645
646 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
647 GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
648 {
649 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
650
651 TfLiteQuantization quant = inputTensor->quantization;
652
653 if (kTfLiteAffineQuantization == quant.type) {
654
655 auto* quantParams = (TfLiteAffineQuantization*) quant.params;
656 const float quantScale = quantParams->scale->data[0];
657 const int quantOffset = quantParams->zero_point->data[0];
658
659 switch (inputTensor->type) {
660 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100661 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
662 cacheSize,
663 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
664 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
665 quantScale,
666 quantOffset);
667 }
alexander3c798932021-03-26 21:42:19 +0000668 );
669 break;
670 }
671 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100672 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
673 cacheSize,
674 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
675 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
676 quantScale,
677 quantOffset);
678 }
alexander3c798932021-03-26 21:42:19 +0000679 );
680 break;
681 }
682 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100683 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
684 cacheSize,
685 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
686 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
687 quantScale,
688 quantOffset);
689 }
alexander3c798932021-03-26 21:42:19 +0000690 );
691 break;
692 }
693 default:
694 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
695 }
696
697
698 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100699 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
700 cacheSize,
701 [&mfcc](std::vector<int16_t>& audioDataWindow) {
702 return mfcc.MfccCompute(audioDataWindow);
703 });
alexander3c798932021-03-26 21:42:19 +0000704 }
705 return mfccFeatureCalc;
706 }
707} /* namespace app */
708} /* namespace arm */