/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "DsCnnModel.hpp"
#include "DsCnnMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };

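    /**
     * @brief Aggregate for passing the KWS pipeline outcome to the ASR stage:
     *        whether KWS executed successfully, and the pointer/length of the
     *        remaining audio that ASR should process.
     **/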
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief Helper function to increment the current audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     * @param[in] idx Value to be set.
     * @return true if the index is set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents KWS inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Presents ASR inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populate input tensor data with
     *        MFCC data.
     *
     * The input tensor data type is checked to choose the correct MFCC feature data type.
     * If the tensor has an integer data type then the original features are quantised.
     *
     * Warning: the MFCC calculator provided as input must have the same life scope as the returned
     * function.
     *
     * @param[in] mfcc MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
     *
     * @return Function to be called providing an audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize);

    /**
     * @brief Performs the KWS pipeline.
     * @param[in,out] ctx Reference to the application context object.
     *
     * @return KWSOutput struct containing the pointer to the audio data where ASR should begin
     *         and how much data to process.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                      (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;
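        /* As a worked example (assuming the typical DS-CNN settings plumbed
         * into the context by the main application: 640-sample MFCC frame,
         * 320-sample frame stride, 49 audio windows):
         * 49 * 320 + (640 - 320) = 16000 samples, i.e. 1 second @ 16 kHz. */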

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %u\n", kwsAudioDataWindowSize);

        /* Stride must be a multiple of the MFCC features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }
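        /* With the example values above, 16000 / 2 = 8000 and 8000 % 320 == 0,
         * so the stride is left unchanged. Were the audio stride not a
         * multiple of the MFCC stride, it would be rounded down so that cached
         * feature vectors stay aligned across consecutive windows. */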

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride / kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating an MFCC features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate the number of feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;
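        /* Continuing the example: a 16000-sample window with a 640-sample MFCC
         * frame and 320-sample stride yields 48 + 1 = 49 feature vectors per
         * window; an 8000-sample audio stride covers 8000 / 320 = 25 of them,
         * so 49 - 25 = 24 vectors from the previous window can be reused. */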

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc) {
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %u => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            kwsClassifier.GetClassificationResults(
                    kwsOutputTensor, kwsClassificationResult,
                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                    kws::KwsResult(
                            kwsClassificationResult,
                            audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                            audioDataSlider.Index(), kwsScoreThreshold)
                    );

            /* Keyword detected. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                         (audioDataSlider.NextWindowStartIndex() -
                                          kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }
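            /* Note on the arithmetic above: ASR starts immediately after the
             * KWS window that triggered the detection. NextWindowStartIndex()
             * has already advanced by one stride, so subtracting
             * kwsAudioDataStride recovers the start of the current window;
             * adding the window size gives the first sample past it, and the
             * remaining sample count follows. */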

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief Performs the ASR pipeline.
     *
     * @param[in,out] ctx Reference to the application context object.
     * @param[in] kwsOutput Struct containing the pointer to the audio data where ASR should begin
     *                      and how much data to process.
     * @return true if the pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by this point. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %u\n", asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
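        /* As a worked example (assuming the typical Wav2Letter settings set up
         * by the application: 296 input rows, context length 98, 512-sample
         * MFCC frame, 160-sample frame stride):
         * inner length = 296 - 2 * 98 = 100 rows,
         * window length = 295 * 160 + 512 = 47712 samples,
         * window stride = 100 * 160 = 16000 samples (1 second @ 16 kHz). */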

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %u\n", asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If there is not enough audio left, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
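            /* The final argument flags the last window. A sketch of the
             * assumed post-processing behaviour, per the context/inner split
             * set up above: the context rows of the output are erased so that
             * overlapping windows are not decoded twice, while the first and
             * last windows keep their leading/trailing edges. */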

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioParamsWinStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* If the request has a valid index, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %u (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!results[i].m_resultVec.empty()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %u); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %u: %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %u: %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function to compute features using a features cache.
     * The actual feature computation is performed by the lambda function provided as a parameter.
     * Features are written to input tensor memory.
     *
     * @tparam T Feature vector type.
     * @param[in] inputTensor Model input tensor pointer.
     * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param[in] compute Features calculator function.
     * @return Lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
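        /* Note: the cache has static storage duration, so there is one cache
         * per template instantiation (i.e. per feature type T) and it
         * persists across calls. This assumes the factory is used for a
         * single model input tensor, as it is in this application. */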

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows
             * overlap. The overlap is at the beginning of the sliding window, with the size
             * of the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
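            /* Feature vectors are laid out contiguously in the input tensor,
             * one per sliding-window index, so the vector for window `index`
             * starts at tensorData + index * size. */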

            /* Start renewing the cache as soon as the iteration goes beyond the window overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float> (std::vector<int16_t>&)> compute);

    static std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];
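            /* For affine quantisation, a real value r maps to a quantised
             * value q as q = r / scale + zero_point (rounded); the quantised
             * MFCC compute calls below are given this scale and offset so
             * they can write features in the tensor's integer domain. */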

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
                                                          cacheSize,
                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                                   quantScale,
                                                                                                   quantOffset);
                                                          }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }

        } else {
            mfccFeatureCalc = FeatureCalc<float>(inputTensor,
                                                 cacheSize,
                                                 [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                                     return mfcc.MfccCompute(audioDataWindow);
                                                 });
        }
        return mfccFeatureCalc;
    }

} /* namespace app */
} /* namespace arm */