/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "MicroNetKwsModel.hpp"
#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterModel.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "KwsProcessing.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
#include "log_macros.h"


using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

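    /* Aggregated output of the KWS stage: whether it ran successfully and, if the
     * trigger keyword was detected, where in the audio clip ASR should pick up. */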
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief       Presents KWS inference results.
     * @param[in]   results     Vector of KWS classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);

    /**
     * @brief       Presents ASR inference results.
     * @param[in]   results     Vector of ASR classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);

    /**
     * @brief           Performs the KWS pipeline.
     * @param[in,out]   ctx     Reference to the application context object.
     * @return          Struct containing a pointer to the audio data where ASR should begin
     *                  and how much data to process.
     **/
    static KWSOutput doKws(ApplicationContext& ctx)
    {
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsModel");
        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

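        /* The input tensor must have at least as many dimensions as the larger of the
         * model's row/column index constants so that those indices can be used safely. */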
        constexpr int minTensorDims = static_cast<int>(
            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
             MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);

        /* Output struct from doing KWS. */
        KWSOutput output {};

        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        /* Get input shape for feature extraction. */
        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;

        /* Set up pre and post-processing. */
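        /* KwsPreProcess works out how many audio samples each inference needs
         * (m_audioDataWindowSize) and how far successive windows advance
         * (m_audioDataStride) from the MFCC frame length and stride; the sliding
         * window created further down is built from those values. */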
        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
                                                 kwsMfccFrameLength, kwsMfccFrameStride);

        std::vector<ClassificationResult> singleInfResult;
        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier&>("kwsClassifier"),
                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
                                                    singleInfResult);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
            get_audio_array(currentIndex),
            get_audio_array_size(currentIndex),
            preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);

        /* Declare a container to hold kws results from across the whole audio clip. */
        std::vector<kws::KwsResult> finalResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running KWS inference... "};
        hal_lcd_display_text(
            str_inf.c_str(), str_inf.size(),
            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* The first window does not have cache ready. */
            preProcess.m_audioWindowIndex = audioDataSlider.Index();

            /* Run the pre-processing, inference and post-processing. */
            if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
                printf_err("KWS pre-processing failed\n");
                return output;
            }

            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            if (!postProcess.DoPostProcess()) {
                printf_err("KWS post-processing failed\n");
                return output;
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Add results from this window to our final results vector. */
            finalResults.emplace_back(
                kws::KwsResult(singleInfResult,
                               audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
                               audioDataSlider.Index(), kwsScoreThreshold));

            /* Break out when trigger keyword is detected. */
            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
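                /* Hand over to ASR: start just past the current KWS window and
                 * process whatever audio remains in the clip after it. */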
                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                         (audioDataSlider.NextWindowStartIndex() -
                                          preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(finalResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief           Performs the ASR pipeline.
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing a pointer to the audio data where ASR
     *                              should begin and how much data to process.
     * @return          true if the pipeline executed without failure.
     **/
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
    {
        auto& asrModel = ctx.Get<Model&>("asrModel");
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        hal_lcd_clear(COLOR_BLACK);

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);

        /* Get input shape. Dimensions of the tensor should have been verified by
         * the callee. */
        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);

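        /* Each Wav2Letter input window is made up of a left context, an inner section and a
         * right context; only the inner rows produce new output, so successive windows
         * overlap by 2 * asrInputCtxLen feature vectors. */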
        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

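        /* A full window needs (asrInputRows - 1) * frameStride + frameLen audio samples
         * to produce asrInputRows MFCC feature vectors. */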
        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccFrameLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccFrameLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
            audioBuffer.data(),
            audioBuffer.size(),
            asrAudioDataWindowLen,
            asrAudioDataWindowStride);

        /* Declare a container for results. */
        std::vector<asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioDataWindowLen;

        /* Set up pre and post-processing objects. */
        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
                                                    asrMfccFrameLen, asrMfccFrameStride);

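        /* The input context length maps to a corresponding context length in the model's
         * output time dimension; the post-process uses it to handle the overlapping
         * context sections of each window's output. */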
        std::vector<ClassificationResult> singleInfResult;
        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
        AsrPostProcess asrPostProcess = AsrPostProcess(
            asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
            ctx.Get<std::vector<std::string>&>("asrLabels"),
            singleInfResult, outputCtxLen,
            Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx);

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Run the pre-processing, inference and post-processing. */
            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
                printf_err("ASR pre-processing failed\n");
                return false;
            }

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }
alexander3c798932021-03-26 21:42:19 +0000322
Richard Burton4e002792022-05-04 09:45:02 +0100323 /* Post processing needs to know if we are on the last audio window. */
324 asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
325 if (!asrPostProcess.DoPostProcess()) {
326 printf_err("ASR post-processing failed.");
327 return false;
328 }

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
            asrClassifier.GetClassificationResults(
                asrOutputTensor, asrClassificationResult,
                ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

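            /* Store this window's classification results along with its start time
             * (in seconds) within the ASR portion of the clip. */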
            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioDataWindowStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }

        if (!PresentInferenceResult(asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* KWS and ASR inference handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        hal_lcd_clear(COLOR_BLACK);

        /* If the request has a valid index, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

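        /* Run KWS on the current clip and, if the trigger keyword is detected, run ASR on
         * the audio that follows it. When runAll is set, repeat until every clip has been
         * processed once. */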
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                printf_err("KWS failed\n");
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Trigger keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxIfmIdx(ctx, "kws_asr");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16;    /* Row index increment. */

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (auto& result : results) {
            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!result.m_resultVec.empty()) {
                topKeyword = result.m_resultVec[0].m_label;
                score = result.m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(result.m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
                                 dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
                 result.m_timeStamp, result.m_inferenceNumber,
                 result.m_threshold);
            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
                     result.m_resultVec[j].m_label.c_str(),
                     result.m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the result string for this inference using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
                             dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

} /* namespace app */
} /* namespace arm */