blob: 9080348dbe36844dfbe842e8af76bfd9176ccaec [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "hal.h"
20#include "InputFiles.hpp"
21#include "AudioUtils.hpp"
22#include "UseCaseCommonUtils.hpp"
23#include "DsCnnModel.hpp"
24#include "DsCnnMfcc.hpp"
25#include "Classifier.hpp"
26#include "KwsResult.hpp"
27#include "Wav2LetterMfcc.hpp"
28#include "Wav2LetterPreprocess.hpp"
29#include "Wav2LetterPostprocess.hpp"
30#include "AsrResult.hpp"
31#include "AsrClassifier.hpp"
32#include "OutputDecode.hpp"
33
34
35using KwsClassifier = arm::app::Classifier;
36
37namespace arm {
38namespace app {
39
    /* Candidate axes over which the ASR output reduction may be performed.
     * NOTE(review): the reduction axis actually used below comes from
     * Wav2LetterModel::ms_outputRowsIdx; this enum appears unused here —
     * confirm against other translation units before removing. */
    enum AsrOutputReductionAxis {
        AxisRow = 1,
        AxisCol = 2
    };
44
    /* Aggregated result of the KWS stage, handed over to the ASR stage. */
    struct KWSOutput {
        bool executionSuccess = false;          /* True only if the KWS pipeline ran to completion. */
        const int16_t* asrAudioStart = nullptr; /* Start of audio for ASR; nullptr if no keyword was spotted. */
        int32_t asrAudioSamples = 0;            /* Number of samples available from asrAudioStart. */
    };
50
    /**
     * @brief Helper function to increment current audio clip index.
     * @param[in,out] ctx   Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index.
     * @param[in,out] ctx   Reference to the application context object.
     * @param[in]     idx   Value to be set.
     * @return true if index is set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents KWS inference results using the data presentation
     *        object.
     * @param[in] platform   Reference to the hal platform object.
     * @param[in] results    Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Presents ASR inference results using the data presentation
     *        object.
     * @param[in] platform   Reference to the hal platform object.
     * @param[in] results    Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results);

    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose correct MFCC feature data type.
     * If tensor has an integer data type then original features are quantised.
     *
     * Warning: mfcc calculator provided as input must have the same life scope as returned function.
     *
     * @param[in]     mfcc        MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in]     cacheSize   Size of the feature vectors cache (number of feature vectors).
     *
     * @return function to be called providing audio sample and sliding window index.
     **/
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
            GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                                 TfLiteTensor* inputTensor,
                                 size_t cacheSize);
102
    /**
     * @brief Performs the KWS pipeline.
     * @param[in,out] ctx   Reference to the application context object.
     *
     * @return KWSOutput struct containing pointer to audio data where ASR should begin
     *         and how much data to process; executionSuccess is false on any failure.
     */
    static KWSOutput doKws(ApplicationContext& ctx) {
        /* LCD coordinates for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        /* Smallest tensor rank that still contains both the rows and cols axes. */
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
            arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        /* Default-initialised => executionSuccess stays false until the very end. */
        KWSOutput output;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsmodel");
        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* KWS front-end parameters supplied via the application context. */
        const int kwsFrameLength = ctx.Get<int>("kwsframeLength");
        const int kwsFrameStride = ctx.Get<int>("kwsframeStride");
        const float kwsScoreThreshold = ctx.Get<float>("kwsscoreThreshold");

        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);

        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        const uint32_t kwsNumMfccFeats = ctx.Get<uint32_t>("kwsNumMfcc");
        const uint32_t kwsNumAudioWindows = ctx.Get<uint32_t>("kwsNumAudioWins");

        audio::DsCnnMFCC kwsMfcc = audio::DsCnnMFCC(kwsNumMfccFeats, kwsFrameLength);
        kwsMfcc.Init();

        /* Deduce the data length required for 1 KWS inference from the network parameters. */
        auto kwsAudioDataWindowSize = kwsNumAudioWindows * kwsFrameStride +
                                      (kwsFrameLength - kwsFrameStride);
        auto kwsMfccWindowSize = kwsFrameLength;
        auto kwsMfccWindowStride = kwsFrameStride;

        /* We are choosing to move by half the window size => for a 1 second window size,
         * this means an overlap of 0.5 seconds. */
        auto kwsAudioDataStride = kwsAudioDataWindowSize / 2;

        info("KWS audio data window size %" PRIu32 "\n", kwsAudioDataWindowSize);

        /* Stride must be multiple of mfcc features window stride to re-use features. */
        if (0 != kwsAudioDataStride % kwsMfccWindowStride) {
            kwsAudioDataStride -= kwsAudioDataStride % kwsMfccWindowStride;
        }

        auto kwsMfccVectorsInAudioStride = kwsAudioDataStride/kwsMfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0/audio::DsCnnMFCC::ms_defaultSamplingFreq;

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        /* Creating a mfcc features sliding window for the data required for 1 inference. */
        auto kwsAudioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                kwsAudioDataWindowSize, kwsMfccWindowSize,
                kwsMfccWindowStride);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                get_audio_array(currentIndex),
                get_audio_array_size(currentIndex),
                kwsAudioDataWindowSize, kwsAudioDataStride);

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused. */
        size_t numberOfReusedFeatureVectors = kwsAudioMFCCWindowSlider.TotalStrides() + 1
                                              - kwsMfccVectorsInAudioStride;

        auto kwsMfccFeatureCalc = GetFeatureCalculator(kwsMfcc, kwsInputTensor,
                                                       numberOfReusedFeatureVectors);

        /* GetFeatureCalculator returns an empty std::function on failure
         * (e.g. unsupported input tensor type). */
        if (!kwsMfccFeatureCalc){
            return output;
        }

        /* Container for KWS results. */
        std::vector<arm::app::kws::KwsResult> kwsResults;

        /* Display message on the LCD - inference running. */
        auto& platform = ctx.Get<hal_platform&>("platform");
        std::string str_inf{"Running KWS inference... "};
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* We moved to the next window - set the features sliding to the new address. */
            kwsAudioMFCCWindowSlider.Reset(inferenceWindow);

            /* The first window does not have cache ready. */
            bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

            /* Start calculating features inside one audio sliding window. */
            while (kwsAudioMFCCWindowSlider.HasNext()) {
                const int16_t* kwsMfccWindow = kwsAudioMFCCWindowSlider.Next();
                std::vector<int16_t> kwsMfccAudioData =
                    std::vector<int16_t>(kwsMfccWindow, kwsMfccWindow + kwsMfccWindowSize);

                /* Compute features for this window and write them to input tensor. */
                kwsMfccFeatureCalc(kwsMfccAudioData,
                                   kwsAudioMFCCWindowSlider.Index(),
                                   useCache,
                                   kwsMfccVectorsInAudioStride);
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            std::vector<ClassificationResult> kwsClassificationResult;
            auto& kwsClassifier = ctx.Get<KwsClassifier&>("kwsclassifier");

            /* Top-1 classification for this window. */
            kwsClassifier.GetClassificationResults(
                            kwsOutputTensor, kwsClassificationResult,
                            ctx.Get<std::vector<std::string>&>("kwslabels"), 1);

            kwsResults.emplace_back(
                kws::KwsResult(
                    kwsClassificationResult,
                    audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * kwsAudioDataStride,
                    audioDataSlider.Index(), kwsScoreThreshold)
                );

            /* Keyword detected. */
            if (kwsClassificationResult[0].m_labelIdx == ctx.Get<uint32_t>("keywordindex")) {
                /* ASR starts immediately after the window that triggered the
                 * keyword; the sample count is the remainder of the clip. */
                output.asrAudioStart = inferenceWindow + kwsAudioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                        kwsAudioDataStride + kwsAudioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase the "inference running" banner. */
        str_inf = std::string(str_inf.size(), ' ');
        platform.data_psn->present_data_text(
                            str_inf.c_str(), str_inf.size(),
                            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(platform, kwsResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }
286
    /**
     * @brief Performs the ASR pipeline.
     *
     * @param[in,out] ctx        Reference to the application context object.
     * @param[in]     kwsOutput  Struct containing pointer to audio data where ASR should begin
     *                           and how much data to process. Caller must ensure asrAudioStart
     *                           is non-null and asrAudioSamples > 0.
     * @return bool true if pipeline executed without failure.
     */
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput) {
        /* LCD coordinates for the "inference running" banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        /* Get model reference. */
        auto& asrModel = ctx.Get<Model&>("asrmodel");
        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto asrScoreThreshold = ctx.Get<float>("asrscoreThreshold");

        /* Dimensions of the tensor should have been verified by the callee. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
        const uint32_t asrInputRows = asrInputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate ASR MFCC related parameters. */
        auto asrMfccParamsWinLen = ctx.Get<uint32_t>("asrframeLength");
        auto asrMfccParamsWinStride = ctx.Get<uint32_t>("asrframeStride");

        /* Populate ASR inference context and inner lengths for input.
         * The input is split into [ctx | inner | ctx] rows; only the inner
         * rows advance the stride between successive inferences. */
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths.
         * Note asrInputInnerLen is computed from unsigned arithmetic, so an
         * oversized ctxLen would wrap; both conditions guard against that. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioParamsWinLen = (asrInputRows - 1) *
                                              asrMfccParamsWinStride + (asrMfccParamsWinLen);
        const uint32_t asrAudioParamsWinStride = asrInputInnerLen * asrMfccParamsWinStride;
        const float asrAudioParamsSecondsPerSample =
                                              (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects. */
        auto& asrPrep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& asrPostp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccParamsWinLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccParamsWinLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                audioBuffer.data(),
                audioBuffer.size(),
                asrAudioParamsWinLen,
                asrAudioParamsWinStride);

        /* Declare a container for results. */
        std::vector<arm::app::asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        platform.data_psn->present_data_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioParamsWinLen;

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioParamsWinLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Calculate MFCCs, deltas and populate the input tensor. */
            asrPrep.Invoke(asrInferenceWindow, asrInferenceWindowLen, asrInputTensor);

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-process; the final flag tells the post-processor this is
             * the last window of the clip. */
            asrPostp.Invoke(asrOutputTensor, reductionAxis, !audioDataSlider.HasNext());

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrclassifier");
            asrClassifier.GetClassificationResults(
                    asrOutputTensor, asrClassificationResult,
                    ctx.Get<std::vector<std::string>&>("asrlabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                (audioDataSlider.Index() *
                                                asrAudioParamsSecondsPerSample *
                                                asrAudioParamsWinStride),
                                                audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            arm::app::DumpTensor(asrOutputTensor, asrOutputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase the "inference running" banner. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }
        if (!PresentInferenceResult(platform, asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }
434
435 /* Audio inference classification handler. */
436 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
437 {
438 auto& platform = ctx.Get<hal_platform&>("platform");
439 platform.data_psn->clear(COLOR_BLACK);
440
441 /* If the request has a valid size, set the audio index. */
442 if (clipIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +0100443 if (!SetAppCtxClipIdx(ctx, clipIndex)) {
alexander3c798932021-03-26 21:42:19 +0000444 return false;
445 }
446 }
447
448 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
449
450 do {
451 KWSOutput kwsOutput = doKws(ctx);
452 if (!kwsOutput.executionSuccess) {
453 return false;
454 }
455
456 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
457 info("Keyword spotted\n");
458 if(!doAsr(ctx, kwsOutput)) {
459 printf_err("ASR failed");
460 return false;
461 }
462 }
463
alexanderc350cdc2021-04-29 20:36:09 +0100464 IncrementAppCtxClipIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000465
466 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
467
468 return true;
469 }
470
alexanderc350cdc2021-04-29 20:36:09 +0100471 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000472 {
473 auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");
474
475 if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
476 ctx.Set<uint32_t>("clipIndex", 0);
477 return;
478 }
479 ++curAudioIdx;
480 ctx.Set<uint32_t>("clipIndex", curAudioIdx);
481 }
482
alexanderc350cdc2021-04-29 20:36:09 +0100483 static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000484 {
485 if (idx >= NUMBER_OF_FILES) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100486 printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
alexander3c798932021-03-26 21:42:19 +0000487 idx, NUMBER_OF_FILES);
488 return false;
489 }
490 ctx.Set<uint32_t>("clipIndex", idx);
491 return true;
492 }
493
alexanderc350cdc2021-04-29 20:36:09 +0100494 static bool PresentInferenceResult(hal_platform& platform,
495 std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000496 {
497 constexpr uint32_t dataPsnTxtStartX1 = 20;
498 constexpr uint32_t dataPsnTxtStartY1 = 30;
499 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
500
501 platform.data_psn->set_text_color(COLOR_GREEN);
502
503 /* Display each result. */
504 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
505
506 for (uint32_t i = 0; i < results.size(); ++i) {
507
508 std::string topKeyword{"<none>"};
509 float score = 0.f;
510
alexanderc350cdc2021-04-29 20:36:09 +0100511 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000512 topKeyword = results[i].m_resultVec[0].m_label;
513 score = results[i].m_resultVec[0].m_normalisedVal;
514 }
515
516 std::string resultStr =
517 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
518 std::string{"s: "} + topKeyword + std::string{" ("} +
519 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
520
521 platform.data_psn->present_data_text(
522 resultStr.c_str(), resultStr.size(),
523 dataPsnTxtStartX1, rowIdx1, 0);
524 rowIdx1 += dataPsnTxtYIncr;
525
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100526 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
alexander3c798932021-03-26 21:42:19 +0000527 results[i].m_timeStamp, results[i].m_inferenceNumber,
528 results[i].m_threshold);
529 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100530 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
alexander3c798932021-03-26 21:42:19 +0000531 results[i].m_resultVec[j].m_label.c_str(),
532 results[i].m_resultVec[j].m_normalisedVal);
533 }
534 }
535
536 return true;
537 }
538
alexanderc350cdc2021-04-29 20:36:09 +0100539 static bool PresentInferenceResult(hal_platform& platform, std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000540 {
541 constexpr uint32_t dataPsnTxtStartX1 = 20;
542 constexpr uint32_t dataPsnTxtStartY1 = 80;
543 constexpr bool allow_multiple_lines = true;
544
545 platform.data_psn->set_text_color(COLOR_GREEN);
546
547 /* Results from multiple inferences should be combined before processing. */
548 std::vector<arm::app::ClassificationResult> combinedResults;
549 for (auto& result : results) {
550 combinedResults.insert(combinedResults.end(),
551 result.m_resultVec.begin(),
552 result.m_resultVec.end());
553 }
554
555 for (auto& result : results) {
556 /* Get the final result string using the decoder. */
557 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
558
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100559 info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
alexander3c798932021-03-26 21:42:19 +0000560 infResultStr.c_str());
561 }
562
563 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
564
565 platform.data_psn->present_data_text(
566 finalResultStr.c_str(), finalResultStr.size(),
567 dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);
568
569 info("Final result: %s\n", finalResultStr.c_str());
570 return true;
571 }
572
573 /**
574 * @brief Generic feature calculator factory.
575 *
576 * Returns lambda function to compute features using features cache.
577 * Real features math is done by a lambda function provided as a parameter.
578 * Features are written to input tensor memory.
579 *
580 * @tparam T feature vector type.
581 * @param inputTensor model input tensor pointer.
582 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
583 * @param compute features calculator function.
584 * @return lambda function to compute features.
585 **/
586 template<class T>
587 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100588 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
589 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
alexander3c798932021-03-26 21:42:19 +0000590 {
591 /* Feature cache to be captured by lambda function. */
592 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
593
594 return [=](std::vector<int16_t>& audioDataWindow,
595 size_t index,
596 bool useCache,
597 size_t featuresOverlapIndex)
598 {
599 T* tensorData = tflite::GetTensorData<T>(inputTensor);
600 std::vector<T> features;
601
602 /* Reuse features from cache if cache is ready and sliding windows overlap.
603 * Overlap is in the beginning of sliding window with a size of a feature cache.
604 */
605 if (useCache && index < featureCache.size()) {
606 features = std::move(featureCache[index]);
607 } else {
608 features = std::move(compute(audioDataWindow));
609 }
610 auto size = features.size();
611 auto sizeBytes = sizeof(T) * size;
612 std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
613
614 /* Start renewing cache as soon iteration goes out of the windows overlap. */
615 if (index >= featuresOverlapIndex) {
616 featureCache[index - featuresOverlapIndex] = std::move(features);
617 }
618 };
619 }
620
    /* Explicit instantiations of FeatureCalc for every tensor data type the
     * feature-calculator factory below can select: quantised int8/uint8/int16
     * and floating point. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000640
641
642 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
643 GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
644 {
645 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
646
647 TfLiteQuantization quant = inputTensor->quantization;
648
649 if (kTfLiteAffineQuantization == quant.type) {
650
651 auto* quantParams = (TfLiteAffineQuantization*) quant.params;
652 const float quantScale = quantParams->scale->data[0];
653 const int quantOffset = quantParams->zero_point->data[0];
654
655 switch (inputTensor->type) {
656 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100657 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
658 cacheSize,
659 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
660 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
661 quantScale,
662 quantOffset);
663 }
alexander3c798932021-03-26 21:42:19 +0000664 );
665 break;
666 }
667 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100668 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
669 cacheSize,
670 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
671 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
672 quantScale,
673 quantOffset);
674 }
alexander3c798932021-03-26 21:42:19 +0000675 );
676 break;
677 }
678 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100679 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
680 cacheSize,
681 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
682 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
683 quantScale,
684 quantOffset);
685 }
alexander3c798932021-03-26 21:42:19 +0000686 );
687 break;
688 }
689 default:
690 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
691 }
692
693
694 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100695 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
696 cacheSize,
697 [&mfcc](std::vector<int16_t>& audioDataWindow) {
698 return mfcc.MfccCompute(audioDataWindow);
699 });
alexander3c798932021-03-26 21:42:19 +0000700 }
701 return mfccFeatureCalc;
702 }
703} /* namespace app */
704} /* namespace arm */