blob: 1e1a4000e1441689e06fd38f06ceae716e5e7df2 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtoned35a6f2022-02-14 11:55:35 +00002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "hal.h"
20#include "InputFiles.hpp"
21#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000022#include "ImageUtils.hpp"
alexander3c798932021-03-26 21:42:19 +000023#include "UseCaseCommonUtils.hpp"
Kshitij Sisodia76a15802021-12-24 11:05:11 +000024#include "MicroNetKwsModel.hpp"
25#include "MicroNetKwsMfcc.hpp"
alexander3c798932021-03-26 21:42:19 +000026#include "Classifier.hpp"
27#include "KwsResult.hpp"
28#include "Wav2LetterMfcc.hpp"
29#include "Wav2LetterPreprocess.hpp"
30#include "Wav2LetterPostprocess.hpp"
31#include "AsrResult.hpp"
32#include "AsrClassifier.hpp"
33#include "OutputDecode.hpp"
alexander31ae9f02022-02-10 16:15:54 +000034#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000035
36
37using KwsClassifier = arm::app::Classifier;
38
39namespace arm {
40namespace app {
41
42 enum AsrOutputReductionAxis {
43 AxisRow = 1,
44 AxisCol = 2
45 };
46
47 struct KWSOutput {
48 bool executionSuccess = false;
49 const int16_t* asrAudioStart = nullptr;
50 int32_t asrAudioSamples = 0;
51 };
52
53 /**
alexander3c798932021-03-26 21:42:19 +000054 * @brief Presents kws inference results using the data presentation
55 * object.
alexander3c798932021-03-26 21:42:19 +000056 * @param[in] results vector of classification results to be displayed
alexander3c798932021-03-26 21:42:19 +000057 * @return true if successful, false otherwise
58 **/
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010059 static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results);
alexander3c798932021-03-26 21:42:19 +000060
61 /**
62 * @brief Presents asr inference results using the data presentation
63 * object.
     * @param[in] results vector of classification results to be displayed
alexander3c798932021-03-26 21:42:19 +000066 * @return true if successful, false otherwise
67 **/
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010068 static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results);
alexander3c798932021-03-26 21:42:19 +000069
70 /**
71 * @brief Returns a function to perform feature calculation and populates input tensor data with
72 * MFCC data.
73 *
74 * Input tensor data type check is performed to choose correct MFCC feature data type.
75 * If tensor has an integer data type then original features are quantised.
76 *
77 * Warning: mfcc calculator provided as input must have the same life scope as returned function.
78 *
79 * @param[in] mfcc MFCC feature calculator.
80 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
Kshitij Sisodia76a15802021-12-24 11:05:11 +000081 * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
alexander3c798932021-03-26 21:42:19 +000082 *
     * @return function to be called providing audio sample and sliding window index.
84 **/
85 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000086 GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
alexander3c798932021-03-26 21:42:19 +000087 TfLiteTensor* inputTensor,
88 size_t cacheSize);
89
    /**
     * @brief Performs the KWS pipeline: slides over the current audio clip,
     *        computes MFCC features (with caching across overlapping windows),
     *        runs KWS inference per window and stops at the first window whose
     *        top classification matches the configured trigger keyword.
     * @param[in,out] ctx   reference to the application context object
     *
     * @return KWSOutput struct containing pointer to audio data where ASR should begin
     *         and how much data to process; executionSuccess is false on any failure.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        /* LCD text position for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        /* The input tensor must have at least as many dimensions as the larger
         * of the model's row/col axis indices. */
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
             arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);

        KWSOutput output; /* Default-initialised => executionSuccess == false. */

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* KWS framing parameters supplied by the application context. */
        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                        (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);

        /* Stride must be multiple of mfcc features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        /* Number of feature vectors contributed by one audio stride. */
        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating a mfcc features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc){
            /* Unsupported tensor type: GetFeatureCalculator returned an empty function. */
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running KWS inference... "};
        hal_lcd_display_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            kwsClassifier.GetClassificationResults(
                    kwsOutputTensor, kwsClassificationResult,
                    ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected: record where ASR should pick up and stop sliding.
             * ASR starts right after the current KWS window. */
            if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                        kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase the inference banner from the LCD. */
        str_inf = std::string(str_inf.size(), ' ');
        hal_lcd_display_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }
272
    /**
     * @brief Performs the ASR pipeline on the audio remaining after keyword
     *        detection: slides over the buffer, runs Wav2Letter pre-processing,
     *        inference and post-processing per window, and presents the
     *        decoded transcription.
     *
     * @param[in,out] ctx        reference to the application context object
     * @param[in]     kwsOutput  struct containing pointer to audio data where ASR should begin
     *                           and how much data to process
     * @return bool true if pipeline executed without failure
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        /* LCD text position for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        hal_lcd_clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the callee. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input.
         * The "context" rows at both edges of each window overlap with
         * neighbouring windows; only the inner rows advance the stride. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        /* NOTE: only used for time stamp calculation. */
        const float asrAudioParamsSecondsPerSample =
                                        (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        hal_lcd_display_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process: final arg signals whether this is the last window. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                (audioDataSlider.Index() *
                                                asrAudioParamsSecondsPerSample *
                                                asrAudioParamsWinStride),
                                                audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase the inference banner. */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }
419
420 /* Audio inference classification handler. */
421 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
422 {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100423 hal_lcd_clear(COLOR_BLACK);
alexander3c798932021-03-26 21:42:19 +0000424
425 /* If the request has a valid size, set the audio index. */
426 if (clipIndex < NUMBER_OF_FILES) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100427 if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) {
alexander3c798932021-03-26 21:42:19 +0000428 return false;
429 }
430 }
431
432 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
433
434 do {
435 KWSOutput kwsOutput = doKws(ctx);
436 if (!kwsOutput.executionSuccess) {
437 return false;
438 }
439
440 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
441 info("Keyword spotted\n");
442 if(!doAsr(ctx, kwsOutput)) {
443 printf_err("ASR failed");
444 return false;
445 }
446 }
447
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100448 IncrementAppCtxIfmIdx(ctx,"kws_asr");
alexander3c798932021-03-26 21:42:19 +0000449
450 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
451
452 return true;
453 }
454
alexander3c798932021-03-26 21:42:19 +0000455
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100456 static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000457 {
458 constexpr uint32_t dataPsnTxtStartX1 = 20;
459 constexpr uint32_t dataPsnTxtStartY1 = 30;
460 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
461
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100462 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000463
464 /* Display each result. */
465 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
466
467 for (uint32_t i = 0; i < results.size(); ++i) {
468
469 std::string topKeyword{"<none>"};
470 float score = 0.f;
471
alexanderc350cdc2021-04-29 20:36:09 +0100472 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000473 topKeyword = results[i].m_resultVec[0].m_label;
474 score = results[i].m_resultVec[0].m_normalisedVal;
475 }
476
477 std::string resultStr =
478 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
479 std::string{"s: "} + topKeyword + std::string{" ("} +
480 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
481
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100482 hal_lcd_display_text(
alexander3c798932021-03-26 21:42:19 +0000483 resultStr.c_str(), resultStr.size(),
484 dataPsnTxtStartX1, rowIdx1, 0);
485 rowIdx1 += dataPsnTxtYIncr;
486
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100487 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
alexander3c798932021-03-26 21:42:19 +0000488 results[i].m_timeStamp, results[i].m_inferenceNumber,
489 results[i].m_threshold);
490 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100491 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
alexander3c798932021-03-26 21:42:19 +0000492 results[i].m_resultVec[j].m_label.c_str(),
493 results[i].m_resultVec[j].m_normalisedVal);
494 }
495 }
496
497 return true;
498 }
499
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100500 static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000501 {
502 constexpr uint32_t dataPsnTxtStartX1 = 20;
503 constexpr uint32_t dataPsnTxtStartY1 = 80;
504 constexpr bool allow_multiple_lines = true;
505
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100506 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000507
508 /* Results from multiple inferences should be combined before processing. */
509 std::vector<arm::app::ClassificationResult> combinedResults;
510 for (auto& result : results) {
511 combinedResults.insert(combinedResults.end(),
512 result.m_resultVec.begin(),
513 result.m_resultVec.end());
514 }
515
516 for (auto& result : results) {
517 /* Get the final result string using the decoder. */
518 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
519
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100520 info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
alexander3c798932021-03-26 21:42:19 +0000521 infResultStr.c_str());
522 }
523
524 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
525
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100526 hal_lcd_display_text(
alexander3c798932021-03-26 21:42:19 +0000527 finalResultStr.c_str(), finalResultStr.size(),
528 dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
529
530 info("Final result: %s\n", finalResultStr.c_str());
531 return true;
532 }
533
    /**
     * @brief Generic feature calculator factory.
     *
     * Returns lambda function to compute features using features cache.
     * Real features math is done by a lambda function provided as a parameter.
     * Features are written to input tensor memory.
     *
     * NOTE(review): the cache is a function-local static, so it is created once
     * per instantiated T, sized by the cacheSize of the FIRST call — assumes a
     * single pipeline instance per feature type; confirm if this is ever reused
     * with a different cacheSize.
     *
     * @tparam T           feature vector type.
     * @param inputTensor  model input tensor pointer.
     * @param cacheSize    number of feature vectors to cache. Defined by the sliding window overlap.
     * @param compute      features calculator function.
     * @return lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
    {
        /* Feature cache to be captured by lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from cache if cache is ready and sliding windows overlap.
             * Overlap is in the beginning of sliding window with a size of a feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            /* Write this window's feature vector into its slot of the input tensor. */
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing cache as soon iteration goes out of the windows overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

    /* Explicit instantiations for the feature types used by GetFeatureCalculator. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000601
602
603 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000604 GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
alexander3c798932021-03-26 21:42:19 +0000605 {
606 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
607
608 TfLiteQuantization quant = inputTensor->quantization;
609
610 if (kTfLiteAffineQuantization == quant.type) {
611
612 auto* quantParams = (TfLiteAffineQuantization*) quant.params;
613 const float quantScale = quantParams->scale->data[0];
614 const int quantOffset = quantParams->zero_point->data[0];
615
616 switch (inputTensor->type) {
617 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100618 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
619 cacheSize,
620 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
621 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
622 quantScale,
623 quantOffset);
624 }
alexander3c798932021-03-26 21:42:19 +0000625 );
626 break;
627 }
628 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100629 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
630 cacheSize,
631 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
632 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
633 quantScale,
634 quantOffset);
635 }
alexander3c798932021-03-26 21:42:19 +0000636 );
637 break;
638 }
639 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100640 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
641 cacheSize,
642 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
643 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
644 quantScale,
645 quantOffset);
646 }
alexander3c798932021-03-26 21:42:19 +0000647 );
648 break;
649 }
650 default:
651 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
652 }
653
654
655 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100656 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
657 cacheSize,
658 [&mfcc](std::vector<int16_t>& audioDataWindow) {
659 return mfcc.MfccCompute(audioDataWindow);
660 });
alexander3c798932021-03-26 21:42:19 +0000661 }
662 return mfccFeatureCalc;
663 }
664} /* namespace app */
665} /* namespace arm */