/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "UseCaseHandler.hpp"
18
19#include "hal.h"
20#include "InputFiles.hpp"
21#include "AudioUtils.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "DsCnnModel.hpp"
24#include "DsCnnMfcc.hpp"
25#include "Classifier.hpp"
26#include "KwsResult.hpp"
27#include "Wav2LetterMfcc.hpp"
28#include "Wav2LetterPreprocess.hpp"
29#include "Wav2LetterPostprocess.hpp"
30#include "AsrResult.hpp"
31#include "AsrClassifier.hpp"
32#include "OutputDecode.hpp"
33
34
35using KwsClassifier = arm::app::Classifier;
36
37namespace arm {
38namespace app {
39
    /* Candidate axes over which the ASR output tensor may be reduced during
     * post-processing.
     * NOTE(review): not referenced in this translation unit — doAsr() uses
     * Wav2LetterModel::ms_outputRowsIdx directly. Confirm external users
     * before removing. */
    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };
44
    /* Result of the KWS stage that is handed over to the ASR stage. */
    struct KWSOutput {
        bool executionSuccess = false;          /* True if the KWS pipeline completed without error. */
        const int16_t* asrAudioStart = nullptr; /* Start of the audio data the ASR stage should process
                                                 * (set only when a keyword was detected). */
        int32_t asrAudioSamples = 0;            /* Number of samples available from asrAudioStart. */
    };
50
    /**
     * @brief Presents KWS inference results using the data presentation
     *        object.
     * @param[in] platform  Reference to the hal platform object.
     * @param[in] results   Vector of KWS classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Presents ASR inference results using the data presentation
     *        object.
     * @param[in] platform  Reference to the hal platform object.
     * @param[in] results   Vector of ASR classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose correct MFCC feature data type.
     * If tensor has an integer data type then original features are quantised.
     *
     * Warning: MFCC calculator provided as input must have the same life scope as returned function.
     *
     * @param[in]     mfcc         MFCC feature calculator.
     * @param[in,out] inputTensor  Input tensor pointer to store calculated features.
     * @param[in]     cacheSize    Size of the feature vectors cache (number of feature vectors).
     *
     * @return Function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize);
88
    /**
     * @brief Performs the KWS pipeline.
     *
     * Slides an analysis window across the current audio clip, computes MFCC
     * features (re-using cached feature vectors in the window overlap region),
     * runs KWS inference per window and records the classification results.
     * Stops at the first window whose top class matches the configured
     * keyword and records where the ASR stage should pick up the audio.
     *
     * @param[in,out] ctx  Reference to the application context object.
     *
     * @return KWSOutput struct containing a pointer to the audio data where ASR
     *         should begin and how much data to process; executionSuccess is
     *         false if any step failed.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        /* LCD coordinates for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        /* The input tensor must have at least as many dimensions as the larger
         * of the model's row/column index constants. */
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
            arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        /* Default-constructed => executionSuccess == false until the end. */
        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                        (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);

        /* Stride must be multiple of mfcc features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        /* Number of feature vectors produced per audio stride. */
        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating a mfcc features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                                            get_audio_array(currentIndex),
                                            kwsAudioDataWindowSize, kwsMfccWindowSize,
                                            kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                                    get_audio_array(currentIndex),
                                    get_audio_array_size(currentIndex),
                                    kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                                - kwsMfccVectorsInAudioStride;

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        if (!kwsMfccFeatureCalc){
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            /* Top-1 classification of this window's output. */
            kwsClassifier.GetClassificationResults(
                            kwsOutputTensor, kwsClassificationResult,
                            ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected: hand the remainder of the clip (starting right
             * after this analysis window) to the ASR stage and stop KWS. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                        kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase the "inference running" banner by overwriting it with spaces. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }
272
    /**
     * @brief Performs the ASR pipeline.
     *
     * Copies the post-keyword audio identified by KWS into a local buffer,
     * slides a (fractional) analysis window over it, and for each window runs
     * pre-processing (MFCC + deltas), inference and post-processing, before
     * presenting the decoded transcription.
     *
     * @param[in,out] ctx        Reference to the application context object.
     * @param[in]     kwsOutput  Struct containing pointer to audio data where ASR should begin
     *                           and how much data to process.
     * @return bool true if pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        /* LCD coordinates for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the caller. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input:
         * ctxLen rows of left/right context surround the inner rows. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        /* Only used for time stamp calculation. */
        const float asrAudioParamsSecondsPerSample =
                                            (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                                    audioBuffer.data(),
                                    audioBuffer.size(),
                                    asrAudioParamsWinLen,
                                    asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process; final flag signals the last window of the clip. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                (audioDataSlider.Index() *
                                                 asrAudioParamsSecondsPerSample *
                                                 asrAudioParamsWinStride),
                                                 audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase the "inference running" banner by overwriting it with spaces. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }
420
421 /* Audio inference classification handler. */
422 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
423 {
424 auto& platform = ctx.Get<hal_platform&>("platform");
425 platform.data_psn->clear(COLOR_BLACK);
426
427 /* If the request has a valid size, set the audio index. */
428 if (clipIndex < NUMBER_OF_FILES) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100429 if (!SetAppCtxIfmIdx(ctx, clipIndex,"kws_asr")) {
alexander3c798932021-03-26 21:42:19 +0000430 return false;
431 }
432 }
433
434 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
435
436 do {
437 KWSOutput kwsOutput = doKws(ctx);
438 if (!kwsOutput.executionSuccess) {
439 return false;
440 }
441
442 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
443 info("Keyword spotted\n");
444 if(!doAsr(ctx, kwsOutput)) {
445 printf_err("ASR failed");
446 return false;
447 }
448 }
449
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100450 IncrementAppCtxIfmIdx(ctx,"kws_asr");
alexander3c798932021-03-26 21:42:19 +0000451
452 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
453
454 return true;
455 }
456
alexander3c798932021-03-26 21:42:19 +0000457
alexanderc350cdc2021-04-29 20:36:09 +0100458 static bool PresentInferenceResult(hal_platform& platform,
459 std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000460 {
461 constexpr uint32_t dataPsnTxtStartX1 = 20;
462 constexpr uint32_t dataPsnTxtStartY1 = 30;
463 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
464
465 platform.data_psn->set_text_color(COLOR_GREEN);
466
467 /* Display each result. */
468 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
469
470 for (uint32_t i = 0; i < results.size(); ++i) {
471
472 std::string topKeyword{"<none>"};
473 float score = 0.f;
474
alexanderc350cdc2021-04-29 20:36:09 +0100475 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000476 topKeyword = results[i].m_resultVec[0].m_label;
477 score = results[i].m_resultVec[0].m_normalisedVal;
478 }
479
480 std::string resultStr =
481 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
482 std::string{"s: "} + topKeyword + std::string{" ("} +
483 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
484
485 platform.data_psn->present_data_text(
486 resultStr.c_str(), resultStr.size(),
487 dataPsnTxtStartX1, rowIdx1, 0);
488 rowIdx1 += dataPsnTxtYIncr;
489
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100490 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
alexander3c798932021-03-26 21:42:19 +0000491 results[i].m_timeStamp, results[i].m_inferenceNumber,
492 results[i].m_threshold);
493 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100494 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
alexander3c798932021-03-26 21:42:19 +0000495 results[i].m_resultVec[j].m_label.c_str(),
496 results[i].m_resultVec[j].m_normalisedVal);
497 }
498 }
499
500 return true;
501 }
502
alexanderc350cdc2021-04-29 20:36:09 +0100503 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000504 {
505 constexpr uint32_t dataPsnTxtStartX1 = 20;
506 constexpr uint32_t dataPsnTxtStartY1 = 80;
507 constexpr bool allow_multiple_lines = true;
508
509 platform.data_psn->set_text_color(COLOR_GREEN);
510
511 /* Results from multiple inferences should be combined before processing. */
512 std::vector<arm::app::ClassificationResult> combinedResults;
513 for (auto& result : results) {
514 combinedResults.insert(combinedResults.end(),
515 result.m_resultVec.begin(),
516 result.m_resultVec.end());
517 }
518
519 for (auto& result : results) {
520 /* Get the final result string using the decoder. */
521 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
522
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100523 info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
alexander3c798932021-03-26 21:42:19 +0000524 infResultStr.c_str());
525 }
526
527 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
528
529 platform.data_psn->present_data_text(
530 finalResultStr.c_str(), finalResultStr.size(),
531 dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
532
533 info("Final result: %s\n", finalResultStr.c_str());
534 return true;
535 }
536
    /**
     * @brief Generic feature calculator factory.
     *
     * Returns lambda function to compute features using features cache.
     * Real features math is done by a lambda function provided as a parameter.
     * Features are written to input tensor memory.
     *
     * NOTE(review): featureCache is a function-local static, so it is created
     * once per template instantiation (per T) and sized from the cacheSize of
     * the FIRST call only — confirm this factory is only invoked once per run
     * for a given T.
     *
     * @tparam T            Feature vector element type.
     * @param  inputTensor  Model input tensor pointer.
     * @param  cacheSize    Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param  compute      Features calculator function.
     * @return Lambda function to compute features.
     **/
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
    {
        /* Feature cache to be captured by lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from cache if cache is ready and sliding windows overlap.
             * Overlap is in the beginning of sliding window with a size of a feature cache. */
            if (useCache && index < featureCache.size()) {
                /* Moves the cached vector out; the slot is refilled below once
                 * the iteration leaves the overlap region. */
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            /* Feature vectors are laid out contiguously in the input tensor,
             * one per sliding-window index. */
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing cache as soon iteration goes out of the windows overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }
584
    /* Explicit instantiations of FeatureCalc for every input tensor element
     * type handled by GetFeatureCalculator below: quantised int8/uint8/int16
     * and float. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000604
605
606 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
607 GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
608 {
609 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
610
611 TfLiteQuantization quant = inputTensor->quantization;
612
613 if (kTfLiteAffineQuantization == quant.type) {
614
615 auto* quantParams = (TfLiteAffineQuantization*) quant.params;
616 const float quantScale = quantParams->scale->data[0];
617 const int quantOffset = quantParams->zero_point->data[0];
618
619 switch (inputTensor->type) {
620 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100621 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
622 cacheSize,
623 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
624 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
625 quantScale,
626 quantOffset);
627 }
alexander3c798932021-03-26 21:42:19 +0000628 );
629 break;
630 }
631 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100632 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
633 cacheSize,
634 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
635 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
636 quantScale,
637 quantOffset);
638 }
alexander3c798932021-03-26 21:42:19 +0000639 );
640 break;
641 }
642 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100643 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
644 cacheSize,
645 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
646 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
647 quantScale,
648 quantOffset);
649 }
alexander3c798932021-03-26 21:42:19 +0000650 );
651 break;
652 }
653 default:
654 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
655 }
656
657
658 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100659 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
660 cacheSize,
661 [&mfcc](std::vector<int16_t>& audioDataWindow) {
662 return mfcc.MfccCompute(audioDataWindow);
663 });
alexander3c798932021-03-26 21:42:19 +0000664 }
665 return mfccFeatureCalc;
666 }
667} /* namespace app */
668} /* namespace arm */