/*
 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "hal.h"
#include "InputFiles.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "MicroNetKwsModel.hpp"
#include "MicroNetKwsMfcc.hpp"
#include "Classifier.hpp"
#include "KwsResult.hpp"
#include "Wav2LetterModel.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "KwsProcessing.hpp"
#include "AsrResult.hpp"
#include "AsrClassifier.hpp"
#include "OutputDecode.hpp"
#include "log_macros.h"


using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

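    /**
     * @brief   Output of the KWS stage: whether it ran successfully and, if the
     *          trigger keyword was spotted, where in the audio clip ASR should
     *          start and how many samples it has to work with.
     **/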
    struct KWSOutput {
        bool executionSuccess = false;
        const int16_t* asrAudioStart = nullptr;
        int32_t asrAudioSamples = 0;
    };

    /**
     * @brief       Presents KWS inference results.
     * @param[in]   results     Vector of KWS classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);

    /**
     * @brief       Presents ASR inference results.
     * @param[in]   results     Vector of ASR classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);

    /**
     * @brief           Performs the KWS pipeline.
     * @param[in,out]   ctx     Reference to the application context object.
     * @return          Struct containing a pointer to the audio data where ASR should begin
     *                  and how much data to process.
     **/
    static KWSOutput doKws(ApplicationContext& ctx)
    {
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto& kwsModel = ctx.Get<Model&>("kwsModel");
        const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
        const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
        const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");

        auto currentIndex = ctx.Get<uint32_t>("clipIndex");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

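        /* The KWS model input is expected to be at least 2-D (MFCC frames as
         * rows, coefficients as columns); taking the larger of the two indices
         * gives the minimum rank the input tensor must have for the checks
         * below. */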
        constexpr int minTensorDims = static_cast<int>(
            (MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)?
             MicroNetKwsModel::ms_inputRowsIdx : MicroNetKwsModel::ms_inputColsIdx);

        /* Output struct from doing KWS. */
        KWSOutput output {};

        if (!kwsModel.IsInited()) {
            printf_err("KWS model has not been initialised\n");
            return output;
        }

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
        TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
        if (!kwsInputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return output;
        } else if (kwsInputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return output;
        }

        /* Get input shape for feature extraction. */
        TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
        const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
        const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];

        /* We expect to be sampling one second's worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float kwsAudioParamsSecondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
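        /* With the MFCC front end's default sampling rate (nominally 16 kHz for
         * this use case) that is 62.5 us per sample; each result below gets a
         * timestamp of window index * audio stride * secondsPerSample. */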

        /* Set up pre and post-processing. */
        KwsPreProcess preProcess = KwsPreProcess(kwsInputTensor, numMfccFeatures, numMfccFrames,
                                                 kwsMfccFrameLength, kwsMfccFrameStride);

        std::vector<ClassificationResult> singleInfResult;
        KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor, ctx.Get<KwsClassifier&>("kwsClassifier"),
                                                    ctx.Get<std::vector<std::string>&>("kwsLabels"),
                                                    singleInfResult);

        /* Creating a sliding window through the whole audio clip. */
        auto audioDataSlider = audio::SlidingWindow<const int16_t>(
            get_audio_array(currentIndex),
            get_audio_array_size(currentIndex),
            preProcess.m_audioDataWindowSize, preProcess.m_audioDataStride);
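        /* Note: each window spans preProcess.m_audioDataWindowSize samples and
         * consecutive windows start preProcess.m_audioDataStride samples apart,
         * so neighbouring windows overlap whenever the stride is smaller than
         * the window size. */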

        /* Declare a container to hold kws results from across the whole audio clip. */
        std::vector<kws::KwsResult> finalResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running KWS inference... "};
        hal_lcd_display_text(
            str_inf.c_str(), str_inf.size(),
            dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
             currentIndex, get_filename(currentIndex));

        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {
            const int16_t* inferenceWindow = audioDataSlider.Next();

            /* Run the pre-processing, inference and post-processing. */
            if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) {
                printf_err("KWS Pre-processing failed.\n");
                return output;
            }

            if (!RunInference(kwsModel, profiler)) {
                printf_err("KWS Inference failed.\n");
                return output;
            }

            if (!postProcess.DoPostProcess()) {
                printf_err("KWS Post-processing failed.\n");
                return output;
            }

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 audioDataSlider.TotalStrides() + 1);

            /* Add results from this window to our final results vector. */
            finalResults.emplace_back(
                kws::KwsResult(singleInfResult,
                               audioDataSlider.Index() * kwsAudioParamsSecondsPerSample * preProcess.m_audioDataStride,
                               audioDataSlider.Index(), kwsScoreThreshold));

            /* Break out when trigger keyword is detected. */
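            /* The handover to ASR starts just past the current KWS window:
             * NextWindowStartIndex() minus the stride is the start of the
             * window that triggered, so adding the window size gives the first
             * sample KWS has not consumed; everything from there to the end of
             * the clip is handed to ASR. */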
            if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword")
                    && singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
                output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
                output.asrAudioSamples = get_audio_array_size(currentIndex) -
                                        (audioDataSlider.NextWindowStartIndex() -
                                         preProcess.m_audioDataStride + preProcess.m_audioDataWindowSize);
                break;
            }

#if VERIFY_TEST_OUTPUT
            DumpTensor(kwsOutputTensor);
#endif /* VERIFY_TEST_OUTPUT */

        } /* while (audioDataSlider.HasNext()) */

        /* Erase. */
        str_inf = std::string(str_inf.size(), ' ');
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        if (!PresentInferenceResult(finalResults)) {
            return output;
        }

        profiler.PrintProfilingResult();

        output.executionSuccess = true;
        return output;
    }

    /**
     * @brief           Performs the ASR pipeline.
     * @param[in,out]   ctx         Reference to the application context object.
     * @param[in]       kwsOutput   Struct containing a pointer to the audio data where ASR
     *                              should begin and how much data to process.
     * @return          true if the pipeline executed without failure.
     **/
    static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
    {
        auto& asrModel = ctx.Get<Model&>("asrModel");
        auto& profiler = ctx.Get<Profiler&>("profiler");
        auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
        auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
        auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
        auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        if (!asrModel.IsInited()) {
            printf_err("ASR model has not been initialised\n");
            return false;
        }

        hal_lcd_clear(COLOR_BLACK);

        /* Get Input and Output tensors for pre/post processing. */
        TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
        TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);

        /* Get input shape. Dimensions of the tensor should have been verified by
         * the callee. */
        TfLiteIntArray* inputShape = asrModel.GetInputShape(0);

        const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
        const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
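        /* Each ASR input window of asrInputRows feature frames is treated as
         * [left context | inner | right context]: only the inner section
         * advances the slider, so consecutive windows overlap by the two
         * context sections. */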

        /* Make sure the input tensor supports the above context and inner lengths. */
        if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
            printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
                       asrInputCtxLen);
            return false;
        }

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t asrAudioDataWindowLen = (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
        const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
        const float asrAudioParamsSecondsPerSample = 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
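        /* Rough worked example, assuming the default build configuration for
         * this use case (296 input rows, frame stride 160, frame length 512,
         * 16 kHz audio): the window is (296 - 1) * 160 + 512 = 47712 samples,
         * i.e. just under 3 seconds of audio per inference. */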

        /* Get the remaining audio buffer and respective size from KWS results. */
        const int16_t* audioArr = kwsOutput.asrAudioStart;
        const uint32_t audioArrSize = kwsOutput.asrAudioSamples;

        /* Audio clip must have enough samples to produce 1 MFCC feature. */
        std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
        if (audioArrSize < asrMfccFrameLen) {
            printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                       asrMfccFrameLen);
            return false;
        }

        /* Initialise an audio slider. */
        auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
            audioBuffer.data(),
            audioBuffer.size(),
            asrAudioDataWindowLen,
            asrAudioDataWindowStride);

        /* Declare a container for results. */
        std::vector<asr::AsrResult> asrResults;

        /* Display message on the LCD - inference running. */
        std::string str_inf{"Running ASR inference... "};
        hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
                             dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        size_t asrInferenceWindowLen = asrAudioDataWindowLen;

        /* Set up pre and post-processing objects. */
        AsrPreProcess asrPreProcess = AsrPreProcess(asrInputTensor, arm::app::Wav2LetterModel::ms_numMfccFeatures,
                                                    inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
                                                    asrMfccFrameLen, asrMfccFrameStride);

        std::vector<ClassificationResult> singleInfResult;
        const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
        AsrPostProcess asrPostProcess = AsrPostProcess(
            asrOutputTensor, ctx.Get<AsrClassifier&>("asrClassifier"),
            ctx.Get<std::vector<std::string>&>("asrLabels"),
            singleInfResult, outputCtxLen,
            Wav2LetterModel::ms_blankTokenIdx, Wav2LetterModel::ms_outputRowsIdx
        );
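        /* outputCtxLen maps the input context length onto the model's output
         * time axis; the post-process is expected to use it to drop the
         * overlapping context rows of each inference so the overlapped audio
         * is not decoded twice. */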
        /* Start sliding through audio clip. */
        while (audioDataSlider.HasNext()) {

            /* If not enough audio is left, see how much can be sent for processing. */
            size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
            if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
                asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
            }

            const int16_t* asrInferenceWindow = audioDataSlider.Next();

            info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

            /* Run the pre-processing, inference and post-processing. */
            if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
                printf_err("ASR pre-processing failed.\n");
                return false;
            }

            /* Run inference over this audio clip sliding window. */
            if (!RunInference(asrModel, profiler)) {
                printf_err("ASR inference failed\n");
                return false;
            }

            /* Post processing needs to know if we are on the last audio window. */
            asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
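            /* The flag simply mirrors the slider state so the post-process can
             * treat the final (possibly shorter) window differently from the
             * ones before it, e.g. when deciding how much context to trim. */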
            if (!asrPostProcess.DoPostProcess()) {
                printf_err("ASR post-processing failed.\n");
                return false;
            }

            /* Get results. */
            std::vector<ClassificationResult> asrClassificationResult;
            auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
            asrClassifier.GetClassificationResults(
                asrOutputTensor, asrClassificationResult,
                ctx.Get<std::vector<std::string>&>("asrLabels"), 1);

            asrResults.emplace_back(asr::AsrResult(asrClassificationResult,
                                                   (audioDataSlider.Index() *
                                                    asrAudioParamsSecondsPerSample *
                                                    asrAudioDataWindowStride),
                                                   audioDataSlider.Index(), asrScoreThreshold));

#if VERIFY_TEST_OUTPUT
            armDumpTensor(asrOutputTensor, asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
                str_inf.c_str(), str_inf.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
        }

        if (!PresentInferenceResult(asrResults)) {
            return false;
        }

        profiler.PrintProfilingResult();

        return true;
    }

    /* KWS and ASR inference handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        hal_lcd_clear(COLOR_BLACK);

        /* If the request has a valid index, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
                return false;
            }
        }

        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

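        /* Process the current clip; when runAll is set, keep going until the
         * clip index wraps back round to the one we started with. */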
        do {
            KWSOutput kwsOutput = doKws(ctx);
            if (!kwsOutput.executionSuccess) {
                printf_err("KWS failed\n");
                return false;
            }

            if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
                info("Trigger keyword spotted\n");
                if (!doAsr(ctx, kwsOutput)) {
                    printf_err("ASR failed\n");
                    return false;
                }
            }

            IncrementAppCtxIfmIdx(ctx, "kws_asr");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16;  /* Row index increment. */

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (auto& result : results) {
            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!result.m_resultVec.empty()) {
                topKeyword = result.m_resultVec[0].m_label;
                score = result.m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(result.m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
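            /* Produces a string of the form "@<timestamp>s: <top keyword> (<score>%)"
             * for the LCD. */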

            hal_lcd_display_text(resultStr.c_str(), resultStr.size(),
                                 dataPsnTxtStartX1, rowIdx1, 0);
            rowIdx1 += dataPsnTxtYIncr;

            info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
                 result.m_timeStamp, result.m_inferenceNumber,
                 result.m_threshold);
            for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
                info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n", j,
                     result.m_resultVec[j].m_label.c_str(),
                     result.m_resultVec[j].m_normalisedVal);
            }
        }

        return true;
    }

    static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 80;
        constexpr bool allow_multiple_lines = true;

        hal_lcd_set_text_color(COLOR_GREEN);

        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

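        /* Decoding the concatenated per-window results below yields the
         * transcript for the whole clip; the per-inference decode in the loop
         * that follows is only logged for inspection. */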
        for (auto& result : results) {
            /* Get the result string for this inference using the decoder. */
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        hal_lcd_display_text(finalResultStr.c_str(), finalResultStr.size(),
                             dataPsnTxtStartX1, dataPsnTxtStartY1, allow_multiple_lines);

        info("Final result: %s\n", finalResultStr.c_str());
        return true;
    }

} /* namespace app */
} /* namespace arm */