/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "MicroNetKwsModel.hpp"
#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "KwsProcessing.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
#include "log_macros.h"

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

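    /* Output of the KWS stage, handed over to the ASR stage: a success flag plus the
     * pointer/length of the audio that remains to be transcribed after the trigger keyword. */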
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief           Presents KWS inference results.
     * @param[in]       results     Vector of KWS classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);

    /**
     * @brief           Presents ASR inference results.
     * @param[in]       results     Vector of ASR classification results to be displayed.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);

    /**
     * @brief           Performs the KWS pipeline.
     * @param[in,out]   ctx     Reference to the application context object.
     * @return          Struct containing a pointer to the audio data at which ASR should
     *                  begin and how much data to process.
     **/
    static KWSOutput doKws(ApplicationContext& ctx)
    {
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsModel");
        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

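        /* The input tensor must have at least as many dimensions as the largest axis index
         * (rows or columns) accessed below, hence the larger of the two indices is used. */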
        constexpr int minTensorDims = static_cast<int>(
            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx) ?
             MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);

        /* Output struct from doing KWS. */
        KWSOutput output {};

        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        /* Get input shape for feature extraction. */
        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;

        /* Set up pre and post-processing. */
        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
                                                 kwsMfccFrameLength, kwsMfccFrameStride);

        std::vector<ClassificationResult> singleInfResult;
        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier&>("kwsClassifier"),
                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
                                                    singleInfResult);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
            get_audio_array(currentIndex),
            get_audio_array_size(currentIndex),
            preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
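        /* Note: the window size and stride above are derived by the pre-processing object
         * from the number of MFCC frames the model expects and the MFCC frame length/stride. */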

        /* Declare a container to hold kws results from across the whole audio clip. */
        std::vector<kws::KwsResult> finalResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running KWS inference... "};
        hal_lcd_display_text(
            str_inf.c_str(), str_inf.size(),
            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* The first window does not have cache ready. */
            preProcess.m_audioWindowIndex = audioDataSlider.Index();

            /* Run the pre-processing, inference and post-processing. */
            if (!preProcess.DoPreProcess(inferenceWindow, audio::MicroNetKwsMFCC::ms_defaultSamplingFreq)) {
                printf_err("KWS pre-processing failed\n");
                return output;
            }

            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS inference failed\n");
                return output;
            }

            if (!postProcess.DoPostProcess()) {
                printf_err("KWS post-processing failed\n");
                return output;
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Add results from this window to our final results vector. */
            finalResults.emplace_back(
                kws::KwsResult(singleInfResult,
                               audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
                               audioDataSlider.Index(), kwsScoreThreshold));

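            /* If the keyword is detected below, ASR picks up immediately after the current
             * window: asrAudioStart points just past this window and asrAudioSamples is
             * whatever remains of the clip beyond it. */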
            /* Break out when trigger keyword is detected. */
            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                        (audioDataSlider.NextWindowStartIndex() -
                         preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(finalResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief           Performs the ASR pipeline.
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing a pointer to the audio data at which
     *                              ASR should begin and how much data to process.
     * @return          true if the pipeline executed without failure.
     **/
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
    {
        auto& asrModel = ctx.Get<Model&>("asrModel");
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        hal_lcd_clear(COLOR_BLACK);

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);

        /* Get input shape. Dimensions of the tensor should have been verified by
         * the callee. */
        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);

        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
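        /* The window must span enough audio samples to produce asrInputRows MFCC feature
         * vectors: (asrInputRows - 1) frame strides plus one full frame length. The stride
         * advances by the inner (non-context) rows only, so consecutive windows overlap,
         * covering the left/right context regions. */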
        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccFrameLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccFrameLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
            audioBuffer.data(),
            audioBuffer.size(),
            asrAudioDataWindowLen,
            asrAudioDataWindowStride);

        /* Declare a container for results. */
        std::vector<asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioDataWindowLen;

        /* Set up pre and post-processing objects. */
        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
                                                    asrMfccFrameLen, asrMfccFrameStride);

        std::vector<ClassificationResult> singleInfResult;
        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
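        /* The network output has fewer time steps than its input, so the context length
         * expressed in output rows is derived from the input context length before being
         * handed to post-processing. */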
        AsrPostProcess asrPostProcess = AsrPostProcess(
            asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
            ctx.Get<std::vector<std::string>&>("asrLabels"),
            singleInfResult, outputCtxLen,
            Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx);

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio is left, work out how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Run the pre-processing, inference and post-processing. */
            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
                printf_err("ASR pre-processing failed\n");
                return false;
            }

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post-processing needs to know if we are on the last audio window. */
            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
            if (!asrPostProcess.DoPostProcess()) {
                printf_err("ASR post-processing failed\n");
                return false;
            }

            /* Get results. */
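            /* The classifier extracts the top 1 label per output time step; these are used
             * later to build the displayed text via the output decoder. */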
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
            asrClassifier.GetClassificationResults(
                asrOutputTensor, asrClassificationResult,
                ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                    (audioDataSlider.Index() *
                                     asrAudioParamsSecondsPerSample *
                                     asrAudioDataWindowStride),
                                    audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }

        if (!PresentInferenceResult(asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* KWS and ASR inference handler. */
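    /* Illustrative usage (assumes the application context has already been populated by the
     * use case's initialisation code with the models, labels, thresholds and "clipIndex"
     * entries used above):
     *
     *     arm::app::ApplicationContext ctx;
     *     ... populate ctx ...
     *     ClassifyAudioHandler(ctx, 0, true);   // process every bundled clip once
     */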
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        hal_lcd_clear(COLOR_BLACK);

        /* If the request has a valid size, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

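        /* Process one clip per iteration; when runAll is set, keep going until the clip
         * index wraps back around to the one we started with. */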
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                printf_err("KWS failed\n");
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Trigger keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxIfmIdx(ctx, "kws_asr");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr   = 16; /* Row index increment. */

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (auto& result : results) {
            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!result.m_resultVec.empty()) {
                topKeyword = result.m_resultVec[0].m_label;
                score = result.m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(result.m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
                                 dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
                 result.m_timeStamp, result.m_inferenceNumber,
                 result.m_threshold);
            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
                     result.m_resultVec[j].m_label.c_str(),
                     result.m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        for (auto& result : results) {
            /* Get the final result string using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
                             dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

} /* namespace app */
} /* namespace arm */