blob: bfc1d25239704ef6951cb47ac0accc952abae319 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "hal.h"
20#include "InputFiles.hpp"
21#include "AudioUtils.hpp"
22#include "UseCaseCommonUtils.hpp"
Kshitij Sisodia76a15802021-12-24 11:05:11 +000023#include "MicroNetKwsModel.hpp"
24#include "MicroNetKwsMfcc.hpp"
alexander3c798932021-03-26 21:42:19 +000025#include "Classifier.hpp"
26#include "KwsResult.hpp"
27#include "Wav2LetterMfcc.hpp"
28#include "Wav2LetterPreprocess.hpp"
29#include "Wav2LetterPostprocess.hpp"
30#include "AsrResult.hpp"
31#include "AsrClassifier.hpp"
32#include "OutputDecode.hpp"
alexander31ae9f02022-02-10 16:15:54 +000033#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000034
35
36using KwsClassifier = arm::app::Classifier;
37
38namespace arm {
39namespace app {
40
/* Axis identifiers for reducing the ASR output tensor.
 * NOTE(review): not referenced anywhere in this file. */
enum AsrOutputReductionAxis {
    AxisRow = 1,            /* Row axis. */
    AxisCol = AxisRow + 1   /* Column axis (== 2). */
};
45
/* Result of the KWS stage, handed to the ASR stage. */
struct KWSOutput {
    /* Set to true once the KWS pipeline has run to completion. */
    bool executionSuccess{false};
    /* First sample after the KWS window in which the trigger keyword was
     * detected; stays nullptr if no keyword was spotted. */
    const int16_t* asrAudioStart{nullptr};
    /* Number of samples ASR should process, starting at asrAudioStart. */
    int32_t asrAudioSamples{0};
};
51
52 /**
alexander3c798932021-03-26 21:42:19 +000053 * @brief Presents kws inference results using the data presentation
54 * object.
55 * @param[in] platform reference to the hal platform object
56 * @param[in] results vector of classification results to be displayed
alexander3c798932021-03-26 21:42:19 +000057 * @return true if successful, false otherwise
58 **/
alexanderc350cdc2021-04-29 20:36:09 +010059 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results);
alexander3c798932021-03-26 21:42:19 +000060
61 /**
62 * @brief Presents asr inference results using the data presentation
63 * object.
64 * @param[in] platform reference to the hal platform object
65 * @param[in] results vector of classification results to be displayed
alexander3c798932021-03-26 21:42:19 +000066 * @return true if successful, false otherwise
67 **/
alexanderc350cdc2021-04-29 20:36:09 +010068 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results);
alexander3c798932021-03-26 21:42:19 +000069
70 /**
71 * @brief Returns a function to perform feature calculation and populates input tensor data with
72 * MFCC data.
73 *
74 * Input tensor data type check is performed to choose correct MFCC feature data type.
75 * If tensor has an integer data type then original features are quantised.
76 *
77 * Warning: mfcc calculator provided as input must have the same life scope as returned function.
78 *
79 * @param[in] mfcc MFCC feature calculator.
80 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
Kshitij Sisodia76a15802021-12-24 11:05:11 +000081 * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
alexander3c798932021-03-26 21:42:19 +000082 *
83 * @return function function to be called providing audio sample and sliding window index.
84 **/
85 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
Kshitij Sisodia76a15802021-12-24 11:05:11 +000086 GetFeatureCalculator(audio::MicroNetMFCC& mfcc,
alexander3c798932021-03-26 21:42:19 +000087 TfLiteTensor* inputTensor,
88 size_t cacheSize);
89
90 /**
91 * @brief Performs the KWS pipeline.
92 * @param[in,out] ctx pointer to the application context object
93 *
94 * @return KWSOutput struct containing pointer to audio data where ASR should begin
95 * and how much data to process.
96 */
97 static KWSOutput doKws(ApplicationContext& ctx) {
98 constexpr uint32_t dataPsnTxtInfStartX = 20;
99 constexpr uint32_t dataPsnTxtInfStartY = 40;
100
101 constexpr int minTensorDims = static_cast<int>(
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000102 (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
103 arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);
alexander3c798932021-03-26 21:42:19 +0000104
105 KWSOutput output;
106
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100107 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +0000108 auto& kwsModel = ctx.Get<Model&>("kwsmodel");
109 if (!kwsModel.IsInited()) {
110 printf_err("KWS model has not been initialised\n");
111 return output;
112 }
113
114 const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
115 const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
116 const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");
117
118 TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
119 TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
120
121 if (!kwsInputTensor->dims) {
122 printf_err("Invalid input tensor dims\n");
123 return output;
124 } else if (kwsInputTensor->dims->size < minTensorDims) {
125 printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
126 return output;
127 }
128
129 const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
130 const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");
131
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000132 audio::MicroNetMFCC kwsMfcc = audio::MicroNetMFCC(kwsNumMfccFeats, kwsFrameLength);
alexander3c798932021-03-26 21:42:19 +0000133 kwsMfcc.Init();
134
135 /* Deduce the data length required for 1 KWS inference from the network parameters. */
136 auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
137 (kwsFrameLength - kwsFrameStride);
138 auto kwsMfccWindowSize = kwsFrameLength;
139 auto kwsMfccWindowStride = kwsFrameStride;
140
141 /* We are choosing to move by half the window size => for a 1 second window size,
142 * this means an overlap of 0.5 seconds. */
143 auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;
144
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100145 info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);
alexander3c798932021-03-26 21:42:19 +0000146
147 /* Stride must be multiple of mfcc features window stride to re-use features. */
148 if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
149 kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
150 }
151
152 auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;
153
154 /* We expect to be sampling 1 second worth of data at a time
155 * NOTE: This is only used for time stamp calculation. */
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000156 const float kwsAudioParamsSecondsPerSample = 1.0/audio::MicroNetMFCC::ms_defaultSamplingFreq;
alexander3c798932021-03-26 21:42:19 +0000157
158 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
159
160 /* Creating a mfcc features sliding window for the data required for 1 inference. */
161 auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
162 get_audio_array(currentIndex),
163 kwsAudioDataWindowSize, kwsMfccWindowSize,
164 kwsMfccWindowStride);
165
166 /* Creating a sliding window through the whole audio clip. */
167 auto audioDataSlider = audio::SlidingWindow<const int16_t>(
168 get_audio_array(currentIndex),
169 get_audio_array_size(currentIndex),
170 kwsAudioDataWindowSize, kwsAudioDataStride);
171
172 /* Calculate number of the feature vectors in the window overlap region.
173 * These feature vectors will be reused.*/
174 size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
175 - kwsMfccVectorsInAudioStride;
176
177 auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
178 numberOfReusedFeatureVectors);
179
180 if (!kwsMfccFeatureCalc){
181 return output;
182 }
183
184 /* Container for KWS results. */
185 std::vector<arm::app::kws::KwsResult> kwsResults;
186
187 /* Display message on the LCD - inference running. */
188 auto& platform = ctx.Get<hal_platform&>("platform");
189 std::string str_inf{"Running KWS inference... "};
190 platform.data_psn->present_data_text(
191 str_inf.c_str(), str_inf.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100192 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000193
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100194 info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
alexander3c798932021-03-26 21:42:19 +0000195 currentIndex, get_filename(currentIndex));
196
197 /* Start sliding through audio clip. */
198 while (audioDataSlider.HasNext()) {
199 const int16_t* inferenceWindow = audioDataSlider.Next();
200
201 /* We moved to the next window - set the features sliding to the new address. */
202 kwsAudioMFCCWindowSlider.Reset(inferenceWindow);
203
204 /* The first window does not have cache ready. */
205 bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
206
207 /* Start calculating features inside one audio sliding window. */
208 while (kwsAudioMFCCWindowSlider.HasNext()) {
209 const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
210 std::vector<int16_t> kwsMfccAudioData =
211 std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);
212
213 /* Compute features for this window and write them to input tensor. */
214 kwsMfccFeatureCalc(kwsMfccAudioData,
215 kwsAudioMFCCWindowSlider.Index(),
216 useCache,
217 kwsMfccVectorsInAudioStride);
218 }
219
220 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
221 audioDataSlider.TotalStrides() + 1);
222
223 /* Run inference over this audio clip sliding window. */
alexander27b62d92021-05-04 20:46:08 +0100224 if (!RunInference(kwsModel, profiler)) {
225 printf_err("KWS inference failed\n");
226 return output;
227 }
alexander3c798932021-03-26 21:42:19 +0000228
229 std::vector<ClassificationResult> kwsClassificationResult;
230 auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");
231
232 kwsClassifier.GetClassificationResults(
233 kwsOutputTensor, kwsClassificationResult,
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000234 ctx.Get<std::vector<std::string>&>("kwslabels"), 1, true);
alexander3c798932021-03-26 21:42:19 +0000235
236 kwsResults.emplace_back(
237 kws::KwsResult(
238 kwsClassificationResult,
239 audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
240 audioDataSlider.Index(), kwsScoreThreshold)
241 );
242
243 /* Keyword detected. */
Liam Barryb5b32d32021-12-30 11:35:00 +0000244 if (kwsClassificationResult[0].m_label == ctx.Get<const std::string&>("triggerkeyword")) {
alexander3c798932021-03-26 21:42:19 +0000245 output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
246 output.asrAudioSamples = get_audio_array_size(currentIndex) -
247 (audioDataSlider.NextWindowStartIndex() -
248 kwsAudioDataStride + kwsAudioDataWindowSize);
249 break;
250 }
251
252#if VERIFY_TEST_OUTPUT
253 arm::app::DumpTensor(kwsOutputTensor);
254#endif /* VERIFY_TEST_OUTPUT */
255
256 } /* while (audioDataSlider.HasNext()) */
257
258 /* Erase. */
259 str_inf = std::string(str_inf.size(), ' ');
260 platform.data_psn->present_data_text(
261 str_inf.c_str(), str_inf.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100262 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000263
alexanderc350cdc2021-04-29 20:36:09 +0100264 if (!PresentInferenceResult(platform, kwsResults)) {
alexander3c798932021-03-26 21:42:19 +0000265 return output;
266 }
267
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100268 profiler.PrintProfilingResult();
269
alexander3c798932021-03-26 21:42:19 +0000270 output.executionSuccess = true;
271 return output;
272 }
273
274 /**
275 * @brief Performs the ASR pipeline.
276 *
Isabella Gottardi56ee6202021-05-12 08:27:15 +0100277 * @param[in,out] ctx pointer to the application context object
278 * @param[in] kwsOutput struct containing pointer to audio data where ASR should begin
alexander3c798932021-03-26 21:42:19 +0000279 * and how much data to process
280 * @return bool true if pipeline executed without failure
281 */
282 static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
283 constexpr uint32_t dataPsnTxtInfStartX = 20;
284 constexpr uint32_t dataPsnTxtInfStartY = 40;
285
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100286 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +0000287 auto& platform = ctx.Get<hal_platform&>("platform");
288 platform.data_psn->clear(COLOR_BLACK);
289
290 /* Get model reference. */
291 auto& asrModel = ctx.Get<Model&>("asrmodel");
292 if (!asrModel.IsInited()) {
293 printf_err("ASR model has not been initialised\n");
294 return false;
295 }
296
297 /* Get score threshold to be applied for the classifier (post-inference). */
298 auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");
299
300 /* Dimensions of the tensor should have been verified by the callee. */
301 TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
302 TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
303 const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
304
305 /* Populate ASR MFCC related parameters. */
306 auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
307 auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");
308
309 /* Populate ASR inference context and inner lengths for input. */
310 auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
311 const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
312
313 /* Make sure the input tensor supports the above context and inner lengths. */
314 if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100315 printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
316 asrInputCtxLen);
alexander3c798932021-03-26 21:42:19 +0000317 return false;
318 }
319
320 /* Audio data stride corresponds to inputInnerLen feature vectors. */
321 const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
322 asrMfccParamsWinStride + (asrMfccParamsWinLen);
323 const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
324 const float asrAudioParamsSecondsPerSample =
325 (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
326
327 /* Get pre/post-processing objects */
328 auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
329 auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");
330
331 /* Set default reduction axis for post-processing. */
332 const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
333
334 /* Get the remaining audio buffer and respective size from KWS results. */
335 const int16_t* audioArr = kwsOutput.asrAudioStart;
336 const uint32_t audioArrSize = kwsOutput.asrAudioSamples;
337
338 /* Audio clip must have enough samples to produce 1 MFCC feature. */
339 std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
340 if (audioArrSize < asrMfccParamsWinLen) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100341 printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
342 asrMfccParamsWinLen);
alexander3c798932021-03-26 21:42:19 +0000343 return false;
344 }
345
346 /* Initialise an audio slider. */
alexander80eecfb2021-07-06 19:47:59 +0100347 auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
alexander3c798932021-03-26 21:42:19 +0000348 audioBuffer.data(),
349 audioBuffer.size(),
350 asrAudioParamsWinLen,
351 asrAudioParamsWinStride);
352
353 /* Declare a container for results. */
354 std::vector<arm::app::asr::AsrResult> asrResults;
355
356 /* Display message on the LCD - inference running. */
357 std::string str_inf{"Running ASR inference... "};
358 platform.data_psn->present_data_text(
359 str_inf.c_str(), str_inf.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100360 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000361
362 size_t asrInferenceWindowLen = asrAudioParamsWinLen;
363
364 /* Start sliding through audio clip. */
365 while (audioDataSlider.HasNext()) {
366
367 /* If not enough audio see how much can be sent for processing. */
368 size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
369 if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
370 asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
371 }
372
373 const int16_t* asrInferenceWindow = audioDataSlider.Next();
374
375 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
376 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
377
alexander3c798932021-03-26 21:42:19 +0000378 /* Calculate MFCCs, deltas and populate the input tensor. */
379 asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);
380
alexander3c798932021-03-26 21:42:19 +0000381 /* Run inference over this audio clip sliding window. */
alexander27b62d92021-05-04 20:46:08 +0100382 if (!RunInference(asrModel, profiler)) {
383 printf_err("ASR inference failed\n");
384 return false;
385 }
alexander3c798932021-03-26 21:42:19 +0000386
387 /* Post-process. */
388 asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());
389
390 /* Get results. */
391 std::vector<ClassificationResult> asrClassificationResult;
392 auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
393 asrClassifier.GetClassificationResults(
394 asrOutputTensor, asrClassificationResult,
395 ctx.Get<std::vector<std::string>&>("asrlabels"), 1);
396
397 asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
398 (audioDataSlider.Index() *
399 asrAudioParamsSecondsPerSample *
400 asrAudioParamsWinStride),
401 audioDataSlider.Index(), asrScoreThreshold));
402
403#if VERIFY_TEST_OUTPUT
404 arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
405#endif /* VERIFY_TEST_OUTPUT */
406
407 /* Erase */
408 str_inf = std::string(str_inf.size(), ' ');
409 platform.data_psn->present_data_text(
410 str_inf.c_str(), str_inf.size(),
411 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
412 }
alexanderc350cdc2021-04-29 20:36:09 +0100413 if (!PresentInferenceResult(platform, asrResults)) {
alexander3c798932021-03-26 21:42:19 +0000414 return false;
415 }
416
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100417 profiler.PrintProfilingResult();
418
alexander3c798932021-03-26 21:42:19 +0000419 return true;
420 }
421
422 /* Audio inference classification handler. */
423 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
424 {
425 auto& platform = ctx.Get<hal_platform&>("platform");
426 platform.data_psn->clear(COLOR_BLACK);
427
428 /* If the request has a valid size, set the audio index. */
429 if (clipIndex < NUMBER_OF_FILES) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100430 if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) {
alexander3c798932021-03-26 21:42:19 +0000431 return false;
432 }
433 }
434
435 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
436
437 do {
438 KWSOutput kwsOutput = doKws(ctx);
439 if (!kwsOutput.executionSuccess) {
440 return false;
441 }
442
443 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
444 info("Keyword spotted\n");
445 if(!doAsr(ctx, kwsOutput)) {
446 printf_err("ASR failed");
447 return false;
448 }
449 }
450
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100451 IncrementAppCtxIfmIdx(ctx,"kws_asr");
alexander3c798932021-03-26 21:42:19 +0000452
453 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
454
455 return true;
456 }
457
alexander3c798932021-03-26 21:42:19 +0000458
alexanderc350cdc2021-04-29 20:36:09 +0100459 static bool PresentInferenceResult(hal_platform& platform,
460 std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000461 {
462 constexpr uint32_t dataPsnTxtStartX1 = 20;
463 constexpr uint32_t dataPsnTxtStartY1 = 30;
464 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
465
466 platform.data_psn->set_text_color(COLOR_GREEN);
467
468 /* Display each result. */
469 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
470
471 for (uint32_t i = 0; i < results.size(); ++i) {
472
473 std::string topKeyword{"<none>"};
474 float score = 0.f;
475
alexanderc350cdc2021-04-29 20:36:09 +0100476 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000477 topKeyword = results[i].m_resultVec[0].m_label;
478 score = results[i].m_resultVec[0].m_normalisedVal;
479 }
480
481 std::string resultStr =
482 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
483 std::string{"s: "} + topKeyword + std::string{" ("} +
484 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
485
486 platform.data_psn->present_data_text(
487 resultStr.c_str(), resultStr.size(),
488 dataPsnTxtStartX1, rowIdx1, 0);
489 rowIdx1 += dataPsnTxtYIncr;
490
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100491 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
alexander3c798932021-03-26 21:42:19 +0000492 results[i].m_timeStamp, results[i].m_inferenceNumber,
493 results[i].m_threshold);
494 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100495 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
alexander3c798932021-03-26 21:42:19 +0000496 results[i].m_resultVec[j].m_label.c_str(),
497 results[i].m_resultVec[j].m_normalisedVal);
498 }
499 }
500
501 return true;
502 }
503
alexanderc350cdc2021-04-29 20:36:09 +0100504 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000505 {
506 constexpr uint32_t dataPsnTxtStartX1 = 20;
507 constexpr uint32_t dataPsnTxtStartY1 = 80;
508 constexpr bool allow_multiple_lines = true;
509
510 platform.data_psn->set_text_color(COLOR_GREEN);
511
512 /* Results from multiple inferences should be combined before processing. */
513 std::vector<arm::app::ClassificationResult> combinedResults;
514 for (auto& result : results) {
515 combinedResults.insert(combinedResults.end(),
516 result.m_resultVec.begin(),
517 result.m_resultVec.end());
518 }
519
520 for (auto& result : results) {
521 /* Get the final result string using the decoder. */
522 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
523
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100524 info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
alexander3c798932021-03-26 21:42:19 +0000525 infResultStr.c_str());
526 }
527
528 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
529
530 platform.data_psn->present_data_text(
531 finalResultStr.c_str(), finalResultStr.size(),
532 dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
533
534 info("Final result: %s\n", finalResultStr.c_str());
535 return true;
536 }
537
538 /**
539 * @brief Generic feature calculator factory.
540 *
541 * Returns lambda function to compute features using features cache.
542 * Real features math is done by a lambda function provided as a parameter.
543 * Features are written to input tensor memory.
544 *
545 * @tparam T feature vector type.
546 * @param inputTensor model input tensor pointer.
547 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
548 * @param compute features calculator function.
549 * @return lambda function to compute features.
550 **/
551 template<class T>
552 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100553 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
554 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
alexander3c798932021-03-26 21:42:19 +0000555 {
556 /* Feature cache to be captured by lambda function. */
557 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
558
559 return [=](std::vector<int16_t>& audioDataWindow,
560 size_t index,
561 bool useCache,
562 size_t featuresOverlapIndex)
563 {
564 T* tensorData = tflite::GetTensorData<T>(inputTensor);
565 std::vector<T> features;
566
567 /* Reuse features from cache if cache is ready and sliding windows overlap.
568 * Overlap is in the beginning of sliding window with a size of a feature cache.
569 */
570 if (useCache && index < featureCache.size()) {
571 features = std::move(featureCache[index]);
572 } else {
573 features = std::move(compute(audioDataWindow));
574 }
575 auto size = features.size();
576 auto sizeBytes = sizeof(T) * size;
577 std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
578
579 /* Start renewing cache as soon iteration goes out of the windows overlap. */
580 if (index >= featuresOverlapIndex) {
581 featureCache[index - featuresOverlapIndex] = std::move(features);
582 }
583 };
584 }
585
586 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100587 FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
588 size_t cacheSize,
589 std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);
590
591 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
592 FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
alexander3c798932021-03-26 21:42:19 +0000593 size_t cacheSize,
alexanderc350cdc2021-04-29 20:36:09 +0100594 std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);
alexander3c798932021-03-26 21:42:19 +0000595
596 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100597 FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
598 size_t cacheSize,
599 std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);
alexander3c798932021-03-26 21:42:19 +0000600
601 template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100602 FeatureCalc<float>(TfLiteTensor* inputTensor,
603 size_t cacheSize,
604 std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000605
606
607 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000608 GetFeatureCalculator(audio::MicroNetMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
alexander3c798932021-03-26 21:42:19 +0000609 {
610 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
611
612 TfLiteQuantization quant = inputTensor->quantization;
613
614 if (kTfLiteAffineQuantization == quant.type) {
615
616 auto* quantParams = (TfLiteAffineQuantization*) quant.params;
617 const float quantScale = quantParams->scale->data[0];
618 const int quantOffset = quantParams->zero_point->data[0];
619
620 switch (inputTensor->type) {
621 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100622 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
623 cacheSize,
624 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
625 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
626 quantScale,
627 quantOffset);
628 }
alexander3c798932021-03-26 21:42:19 +0000629 );
630 break;
631 }
632 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100633 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
634 cacheSize,
635 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
636 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
637 quantScale,
638 quantOffset);
639 }
alexander3c798932021-03-26 21:42:19 +0000640 );
641 break;
642 }
643 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100644 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
645 cacheSize,
646 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
647 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
648 quantScale,
649 quantOffset);
650 }
alexander3c798932021-03-26 21:42:19 +0000651 );
652 break;
653 }
654 default:
655 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
656 }
657
658
659 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100660 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
661 cacheSize,
662 [&mfcc](std::vector<int16_t>& audioDataWindow) {
663 return mfcc.MfccCompute(audioDataWindow);
664 });
alexander3c798932021-03-26 21:42:19 +0000665 }
666 return mfccFeatureCalc;
667 }
668} /* namespace app */
669} /* namespace arm */