/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "MicroNetKwsModel.hpp"
#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
#include "log_macros.h"


using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };

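    /**
     * @brief   Output of the KWS stage: a success flag, plus the region of
     *          the audio clip (start pointer and number of samples) that the
     *          ASR stage should process once a keyword has been spotted.
     **/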
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief           Presents KWS inference results using the data presentation
     *                  object.
     * @param[in]       platform    Reference to the hal platform object.
     * @param[in]       results     Vector of classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief           Presents ASR inference results using the data presentation
     *                  object.
     * @param[in]       platform    Reference to the hal platform object.
     * @param[in]       results     Vector of classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose the correct MFCC feature data type.
     * If the tensor has an integer data type then the original features are quantised.
     *
     * Warning: the MFCC calculator provided as input must outlive the returned function.
     *
     * @param[in]       mfcc          MFCC feature calculator.
     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
     *
     * @return          Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
            GetFeatureCalculator(audio::MicroNetMFCC&  mfcc,
                                 TfLiteTensor*         inputTensor,
                                 size_t                cacheSize);

    /**
     * @brief Performs the KWS pipeline.
     * @param[in,out]   ctx   Reference to the application context object.
     *
     * @return KWSOutput struct containing a pointer to the audio data where ASR should begin
     *         and how much data to process.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        constexpr int minTensorDims = static_cast<int>(
            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
             arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
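
        /* Note: the dimension check below only requires the input tensor to
         * have at least as many dimensions as the larger of the row/column
         * indices that are dereferenced later on. */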

        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                      (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;
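
        /* Illustrative numbers only (actual values come from the model and
         * build configuration): with, say, 49 audio windows, a frame stride
         * of 320 samples and a frame length of 640 samples, one inference
         * consumes 49 * 320 + (640 - 320) = 16000 samples, i.e. 1 second of
         * audio at 16 kHz. */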

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
alexander3c798932021-03-26 21:42:19 +0000147
148 /* Stride must be multiple of mfcc features window stride to re-use features. */
149 if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
150 kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
151 }
152
153 auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
154
        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetMFCC::ms_defaultSamplingFreq;
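
        /* Assuming the default sampling frequency is 16 kHz, this works out
         * at 1/16000 s, i.e. 62.5 us per sample (illustrative figure). */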

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating a mfcc features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;
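
        /* TotalStrides() + 1 is the total number of feature vectors in one
         * inference window; subtracting the vectors covered by a single audio
         * stride leaves the vectors shared between consecutive windows, which
         * is the feature cache size used below. */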

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc) {
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

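            /* Take the top classification result only; the trailing boolean
             * requests softmax normalisation of the output vector first (an
             * assumption based on this project's Classifier interface). */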
            kwsClassifier.GetClassificationResults(
                            kwsOutputTensor, kwsClassificationResult,
                            ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected. */
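            /* Hand over to ASR everything that follows the triggering window:
             * asrAudioStart points just past the current KWS window, and
             * asrAudioSamples counts the samples remaining from that point to
             * the end of the clip. */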
            if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                         (audioDataSlider.NextWindowStartIndex() -
                                          kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief Performs the ASR pipeline.
     *
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing a pointer to the audio data where ASR should begin
     *                              and how much data to process.
     * @return true if the pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the callee. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

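        /* The Wav2Letter input is laid out as [left context | inner | right
         * context] rows; only the inner rows advance the sliding window while
         * the context rows overlap with neighbouring windows. Illustrative
         * figures (actual values are model dependent): 296 input rows with a
         * context length of 98 give 296 - 2 * 98 = 100 inner rows. */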
        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                                              (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio is left, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
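
            /* The last argument flags the final window, which post-processing
             * treats differently when trimming the context sections (an
             * assumption based on the Invoke() call shape; see
             * Wav2LetterPostprocess for the authoritative behaviour). */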

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioParamsWinStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* If the request has a valid size, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

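        /* IncrementAppCtxIfmIdx() advances and wraps the clip index, so with
         * runAll set the loop below visits every clip once and stops when the
         * index returns to its starting value. */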
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxIfmIdx(ctx, "kws_asr");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }


    static bool PresentInferenceResult(hal_platform& platform,
                                       std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr   = 16;  /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!results[i].m_resultVec.empty()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
                 results[i].m_timeStamp, results[i].m_inferenceNumber,
                 results[i].m_threshold);
            for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
                     results[i].m_resultVec[j].m_label.c_str(),
                     results[i].m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns lambda function to compute features using features cache.
     * Real features math is done by a lambda function provided as a parameter.
     * Features are written to input tensor memory.
     *
     * @tparam T            feature vector type.
     * @param inputTensor   model input tensor pointer.
     * @param cacheSize     number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute       features calculator function.
     * @return              lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
             * The overlap is at the beginning of the sliding window, with the size of the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration goes beyond the windows' overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

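    /* Usage sketch (illustrative only): for a float model the factory is
     * invoked along these lines, which is exactly how GetFeatureCalculator()
     * below wires it up:
     *
     *   auto calc = FeatureCalc<float>(inputTensor, cacheSize,
     *       [&mfcc](std::vector<int16_t>& window) { return mfcc.MfccCompute(window); });
     *   calc(audioWindow, windowIndex, useCache, overlapIndex);
     */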
    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float> (std::vector<int16_t>&)> compute);


    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];

            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
                                                          cacheSize,
                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                                   quantScale,
                                                                                                   quantOffset);
                                                          }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }


        } else {
            mfccFeatureCalc = FeatureCalc<float>(inputTensor,
                                                 cacheSize,
                                                 [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                                     return mfcc.MfccCompute(audioDataWindow);
                                                 });
        }
        return mfccFeatureCalc;
    }
} /* namespace app */
} /* namespace arm */