/*
 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
 * <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include <cinttypes>
#include <cmath>
#include <string>
#include <vector>

#include "AsrClassifier.hpp"
#include "AsrResult.hpp"
#include "AudioUtils.hpp"
#include "Classifier.hpp"
#include "ImageUtils.hpp"
#include "InputFiles.hpp"
#include "KwsProcessing.hpp"
#include "KwsResult.hpp"
#include "MicroNetKwsMfcc.hpp"
#include "MicroNetKwsModel.hpp"
#include "OutputDecode.hpp"
#include "UseCaseCommonUtils.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterModel.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "hal.h"
#include "log_macros.h"

alexander3c798932021-03-26 21:42:19 +000038using KwsClassifier = arm::app::Classifier;
39
40namespace arm {
41namespace app {
42
/**
 * @brief   Outcome of the KWS stage; consumed by the ASR stage.
 **/
struct KWSOutput {
    /* Set to true only when the KWS pipeline ran to completion without an error. */
    bool executionSuccess = false;
    /* First sample following the window in which the trigger keyword was spotted;
     * stays nullptr when the keyword was not spotted. */
    const int16_t* asrAudioStart = nullptr;
    /* Number of samples remaining in the clip from asrAudioStart onwards. */
    int32_t asrAudioSamples = 0;
};

49 /**
Richard Burton4e002792022-05-04 09:45:02 +010050 * @brief Presents KWS inference results.
51 * @param[in] results Vector of KWS classification results to be displayed.
52 * @return true if successful, false otherwise.
alexander3c798932021-03-26 21:42:19 +000053 **/
Richard Burton4e002792022-05-04 09:45:02 +010054 static bool PresentInferenceResult(std::vector<kws::KwsResult>& results);
alexander3c798932021-03-26 21:42:19 +000055
56 /**
Richard Burton4e002792022-05-04 09:45:02 +010057 * @brief Presents ASR inference results.
58 * @param[in] results Vector of ASR classification results to be displayed.
59 * @return true if successful, false otherwise.
alexander3c798932021-03-26 21:42:19 +000060 **/
Richard Burton4e002792022-05-04 09:45:02 +010061 static bool PresentInferenceResult(std::vector<asr::AsrResult>& results);
alexander3c798932021-03-26 21:42:19 +000062
63 /**
Richard Burton4e002792022-05-04 09:45:02 +010064 * @brief Performs the KWS pipeline.
65 * @param[in,out] ctx pointer to the application context object
66 * @return struct containing pointer to audio data where ASR should begin
67 * and how much data to process.
alexander3c798932021-03-26 21:42:19 +000068 **/
Richard Burton4e002792022-05-04 09:45:02 +010069 static KWSOutput doKws(ApplicationContext& ctx)
70 {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000071 auto& profiler = ctx.Get<Profiler&>("profiler");
72 auto& kwsModel = ctx.Get<Model&>("kwsModel");
Richard Burton4e002792022-05-04 09:45:02 +010073 const auto kwsMfccFrameLength = ctx.Get<int>("kwsFrameLength");
74 const auto kwsMfccFrameStride = ctx.Get<int>("kwsFrameStride");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000075 const auto kwsScoreThreshold = ctx.Get<float>("kwsScoreThreshold");
alexander3c798932021-03-26 21:42:19 +000076
Richard Burton4e002792022-05-04 09:45:02 +010077 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
78
alexander3c798932021-03-26 21:42:19 +000079 constexpr uint32_t dataPsnTxtInfStartX = 20;
80 constexpr uint32_t dataPsnTxtInfStartY = 40;
81
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000082 constexpr int minTensorDims =
83 static_cast<int>((MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)
84 ? MicroNetKwsModel::ms_inputRowsIdx
85 : MicroNetKwsModel::ms_inputColsIdx);
alexander3c798932021-03-26 21:42:19 +000086
Richard Burton4e002792022-05-04 09:45:02 +010087 /* Output struct from doing KWS. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000088 KWSOutput output{};
alexander3c798932021-03-26 21:42:19 +000089
alexander3c798932021-03-26 21:42:19 +000090 if (!kwsModel.IsInited()) {
91 printf_err("KWS model has not been initialised\n");
92 return output;
93 }
94
Richard Burton4e002792022-05-04 09:45:02 +010095 /* Get Input and Output tensors for pre/post processing. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000096 TfLiteTensor* kwsInputTensor = kwsModel.GetInputTensor(0);
Richard Burton4e002792022-05-04 09:45:02 +010097 TfLiteTensor* kwsOutputTensor = kwsModel.GetOutputTensor(0);
alexander3c798932021-03-26 21:42:19 +000098 if (!kwsInputTensor->dims) {
99 printf_err("Invalid input tensor dims\n");
100 return output;
101 } else if (kwsInputTensor->dims->size < minTensorDims) {
102 printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
103 return output;
104 }
105
Richard Burton4e002792022-05-04 09:45:02 +0100106 /* Get input shape for feature extraction. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000107 TfLiteIntArray* inputShape = kwsModel.GetInputShape(0);
Richard Burton4e002792022-05-04 09:45:02 +0100108 const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000109 const uint32_t numMfccFrames = inputShape->data[MicroNetKwsModel::ms_inputRowsIdx];
alexander3c798932021-03-26 21:42:19 +0000110
111 /* We expect to be sampling 1 second worth of data at a time
112 * NOTE: This is only used for time stamp calculation. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000113 const float kwsAudioParamsSecondsPerSample =
114 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
alexander3c798932021-03-26 21:42:19 +0000115
Richard Burton4e002792022-05-04 09:45:02 +0100116 /* Set up pre and post-processing. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000117 KwsPreProcess preProcess = KwsPreProcess(
118 kwsInputTensor, numMfccFeatures, numMfccFrames, kwsMfccFrameLength, kwsMfccFrameStride);
alexander3c798932021-03-26 21:42:19 +0000119
Richard Burton4e002792022-05-04 09:45:02 +0100120 std::vector<ClassificationResult> singleInfResult;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000121 KwsPostProcess postProcess = KwsPostProcess(kwsOutputTensor,
122 ctx.Get<KwsClassifier&>("kwsClassifier"),
Richard Burton4e002792022-05-04 09:45:02 +0100123 ctx.Get<std::vector<std::string>&>("kwsLabels"),
124 singleInfResult);
alexander3c798932021-03-26 21:42:19 +0000125
126 /* Creating a sliding window through the whole audio clip. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000127 auto audioDataSlider = audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex),
128 GetAudioArraySize(currentIndex),
129 preProcess.m_audioDataWindowSize,
130 preProcess.m_audioDataStride);
alexander3c798932021-03-26 21:42:19 +0000131
Richard Burton4e002792022-05-04 09:45:02 +0100132 /* Declare a container to hold kws results from across the whole audio clip. */
133 std::vector<kws::KwsResult> finalResults;
alexander3c798932021-03-26 21:42:19 +0000134
135 /* Display message on the LCD - inference running. */
alexander3c798932021-03-26 21:42:19 +0000136 std::string str_inf{"Running KWS inference... "};
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100137 hal_lcd_display_text(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000138 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000139
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100140 info("Running KWS inference on audio clip %" PRIu32 " => %s\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000141 currentIndex,
142 GetFilename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000143
144 /* Start sliding through audio clip. */
145 while (audioDataSlider.HasNext()) {
146 const int16_t* inferenceWindow = audioDataSlider.Next();
147
Richard Burton4e002792022-05-04 09:45:02 +0100148 /* Run the pre-processing, inference and post-processing. */
Richard Burtonec5e99b2022-10-05 11:00:37 +0100149 if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) {
Richard Burton4e002792022-05-04 09:45:02 +0100150 printf_err("KWS Pre-processing failed.");
151 return output;
152 }
alexander3c798932021-03-26 21:42:19 +0000153
Richard Burton4e002792022-05-04 09:45:02 +0100154 if (!RunInference(kwsModel, profiler)) {
155 printf_err("KWS Inference failed.");
156 return output;
157 }
158
159 if (!postProcess.DoPostProcess()) {
160 printf_err("KWS Post-processing failed.");
161 return output;
alexander3c798932021-03-26 21:42:19 +0000162 }
163
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000164 info("Inference %zu/%zu\n",
165 audioDataSlider.Index() + 1,
alexander3c798932021-03-26 21:42:19 +0000166 audioDataSlider.TotalStrides() + 1);
167
Richard Burton4e002792022-05-04 09:45:02 +0100168 /* Add results from this window to our final results vector. */
169 finalResults.emplace_back(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000170 kws::KwsResult(singleInfResult,
171 audioDataSlider.Index() * kwsAudioParamsSecondsPerSample *
172 preProcess.m_audioDataStride,
173 audioDataSlider.Index(),
174 kwsScoreThreshold));
alexander3c798932021-03-26 21:42:19 +0000175
Richard Burton4e002792022-05-04 09:45:02 +0100176 /* Break out when trigger keyword is detected. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000177 if (singleInfResult[0].m_label == ctx.Get<const std::string&>("triggerKeyword") &&
178 singleInfResult[0].m_normalisedVal > kwsScoreThreshold) {
Richard Burton4e002792022-05-04 09:45:02 +0100179 output.asrAudioStart = inferenceWindow + preProcess.m_audioDataWindowSize;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000180 output.asrAudioSamples =
181 GetAudioArraySize(currentIndex) -
182 (audioDataSlider.NextWindowStartIndex() - preProcess.m_audioDataStride +
183 preProcess.m_audioDataWindowSize);
alexander3c798932021-03-26 21:42:19 +0000184 break;
185 }
186
187#if VERIFY_TEST_OUTPUT
Richard Burton4e002792022-05-04 09:45:02 +0100188 DumpTensor(kwsOutputTensor);
alexander3c798932021-03-26 21:42:19 +0000189#endif /* VERIFY_TEST_OUTPUT */
190
191 } /* while (audioDataSlider.HasNext()) */
192
193 /* Erase. */
194 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000195 hal_lcd_display_text(
196 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000197
Richard Burton4e002792022-05-04 09:45:02 +0100198 if (!PresentInferenceResult(finalResults)) {
alexander3c798932021-03-26 21:42:19 +0000199 return output;
200 }
201
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100202 profiler.PrintProfilingResult();
203
alexander3c798932021-03-26 21:42:19 +0000204 output.executionSuccess = true;
205 return output;
206 }
207
208 /**
Richard Burton4e002792022-05-04 09:45:02 +0100209 * @brief Performs the ASR pipeline.
210 * @param[in,out] ctx Pointer to the application context object.
211 * @param[in] kwsOutput Struct containing pointer to audio data where ASR should begin
212 * and how much data to process.
213 * @return true if pipeline executed without failure.
214 **/
215 static bool doAsr(ApplicationContext& ctx, const KWSOutput& kwsOutput)
216 {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000217 auto& asrModel = ctx.Get<Model&>("asrModel");
218 auto& profiler = ctx.Get<Profiler&>("profiler");
219 auto asrMfccFrameLen = ctx.Get<uint32_t>("asrFrameLength");
Richard Burton4e002792022-05-04 09:45:02 +0100220 auto asrMfccFrameStride = ctx.Get<uint32_t>("asrFrameStride");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000221 auto asrScoreThreshold = ctx.Get<float>("asrScoreThreshold");
222 auto asrInputCtxLen = ctx.Get<uint32_t>("ctxLen");
Richard Burton4e002792022-05-04 09:45:02 +0100223
alexander3c798932021-03-26 21:42:19 +0000224 constexpr uint32_t dataPsnTxtInfStartX = 20;
225 constexpr uint32_t dataPsnTxtInfStartY = 40;
226
alexander3c798932021-03-26 21:42:19 +0000227 if (!asrModel.IsInited()) {
228 printf_err("ASR model has not been initialised\n");
229 return false;
230 }
231
Richard Burton4e002792022-05-04 09:45:02 +0100232 hal_lcd_clear(COLOR_BLACK);
alexander3c798932021-03-26 21:42:19 +0000233
Richard Burton4e002792022-05-04 09:45:02 +0100234 /* Get Input and Output tensors for pre/post processing. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000235 TfLiteTensor* asrInputTensor = asrModel.GetInputTensor(0);
alexander3c798932021-03-26 21:42:19 +0000236 TfLiteTensor* asrOutputTensor = asrModel.GetOutputTensor(0);
alexander3c798932021-03-26 21:42:19 +0000237
Richard Burton4e002792022-05-04 09:45:02 +0100238 /* Get input shape. Dimensions of the tensor should have been verified by
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000239 * the callee. */
Richard Burton4e002792022-05-04 09:45:02 +0100240 TfLiteIntArray* inputShape = asrModel.GetInputShape(0);
alexander3c798932021-03-26 21:42:19 +0000241
Richard Burton4e002792022-05-04 09:45:02 +0100242 const uint32_t asrInputRows = asrInputTensor->dims->data[Wav2LetterModel::ms_inputRowsIdx];
alexander3c798932021-03-26 21:42:19 +0000243 const uint32_t asrInputInnerLen = asrInputRows - (2 * asrInputCtxLen);
244
245 /* Make sure the input tensor supports the above context and inner lengths. */
246 if (asrInputRows <= 2 * asrInputCtxLen || asrInputRows <= asrInputInnerLen) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100247 printf_err("ASR input rows not compatible with ctx length %" PRIu32 "\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000248 asrInputCtxLen);
alexander3c798932021-03-26 21:42:19 +0000249 return false;
250 }
251
252 /* Audio data stride corresponds to inputInnerLen feature vectors. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000253 const uint32_t asrAudioDataWindowLen =
254 (asrInputRows - 1) * asrMfccFrameStride + (asrMfccFrameLen);
Richard Burton4e002792022-05-04 09:45:02 +0100255 const uint32_t asrAudioDataWindowStride = asrInputInnerLen * asrMfccFrameStride;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000256 const float asrAudioParamsSecondsPerSample =
257 1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq;
alexander3c798932021-03-26 21:42:19 +0000258
259 /* Get the remaining audio buffer and respective size from KWS results. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000260 const int16_t* audioArr = kwsOutput.asrAudioStart;
alexander3c798932021-03-26 21:42:19 +0000261 const uint32_t audioArrSize = kwsOutput.asrAudioSamples;
262
263 /* Audio clip must have enough samples to produce 1 MFCC feature. */
264 std::vector<int16_t> audioBuffer = std::vector<int16_t>(audioArr, audioArr + audioArrSize);
Richard Burton4e002792022-05-04 09:45:02 +0100265 if (audioArrSize < asrMfccFrameLen) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100266 printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
Richard Burton4e002792022-05-04 09:45:02 +0100267 asrMfccFrameLen);
alexander3c798932021-03-26 21:42:19 +0000268 return false;
269 }
270
271 /* Initialise an audio slider. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000272 auto audioDataSlider =
273 audio::FractionalSlidingWindow<const int16_t>(audioBuffer.data(),
274 audioBuffer.size(),
275 asrAudioDataWindowLen,
276 asrAudioDataWindowStride);
alexander3c798932021-03-26 21:42:19 +0000277
278 /* Declare a container for results. */
Richard Burton4e002792022-05-04 09:45:02 +0100279 std::vector<asr::AsrResult> asrResults;
alexander3c798932021-03-26 21:42:19 +0000280
281 /* Display message on the LCD - inference running. */
282 std::string str_inf{"Running ASR inference... "};
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000283 hal_lcd_display_text(
284 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000285
Richard Burton4e002792022-05-04 09:45:02 +0100286 size_t asrInferenceWindowLen = asrAudioDataWindowLen;
alexander3c798932021-03-26 21:42:19 +0000287
Richard Burton4e002792022-05-04 09:45:02 +0100288 /* Set up pre and post-processing objects. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000289 AsrPreProcess asrPreProcess =
290 AsrPreProcess(asrInputTensor,
291 arm::app::Wav2LetterModel::ms_numMfccFeatures,
292 inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
293 asrMfccFrameLen,
294 asrMfccFrameStride);
Richard Burton4e002792022-05-04 09:45:02 +0100295
296 std::vector<ClassificationResult> singleInfResult;
297 const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(asrModel, asrInputCtxLen);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000298 AsrPostProcess asrPostProcess =
299 AsrPostProcess(asrOutputTensor,
300 ctx.Get<AsrClassifier&>("asrClassifier"),
301 ctx.Get<std::vector<std::string>&>("asrLabels"),
302 singleInfResult,
303 outputCtxLen,
304 Wav2LetterModel::ms_blankTokenIdx,
305 Wav2LetterModel::ms_outputRowsIdx);
alexander3c798932021-03-26 21:42:19 +0000306 /* Start sliding through audio clip. */
307 while (audioDataSlider.HasNext()) {
308
309 /* If not enough audio see how much can be sent for processing. */
310 size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
Richard Burton4e002792022-05-04 09:45:02 +0100311 if (nextStartIndex + asrAudioDataWindowLen > audioBuffer.size()) {
alexander3c798932021-03-26 21:42:19 +0000312 asrInferenceWindowLen = audioBuffer.size() - nextStartIndex;
313 }
314
315 const int16_t* asrInferenceWindow = audioDataSlider.Next();
316
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000317 info("Inference %zu/%zu\n",
318 audioDataSlider.Index() + 1,
319 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
alexander3c798932021-03-26 21:42:19 +0000320
Richard Burton4e002792022-05-04 09:45:02 +0100321 /* Run the pre-processing, inference and post-processing. */
322 if (!asrPreProcess.DoPreProcess(asrInferenceWindow, asrInferenceWindowLen)) {
323 printf_err("ASR pre-processing failed.");
324 return false;
325 }
alexander3c798932021-03-26 21:42:19 +0000326
alexander3c798932021-03-26 21:42:19 +0000327 /* Run inference over this audio clip sliding window. */
alexander27b62d92021-05-04 20:46:08 +0100328 if (!RunInference(asrModel, profiler)) {
329 printf_err("ASR inference failed\n");
330 return false;
331 }
alexander3c798932021-03-26 21:42:19 +0000332
Richard Burton4e002792022-05-04 09:45:02 +0100333 /* Post processing needs to know if we are on the last audio window. */
334 asrPostProcess.m_lastIteration = !audioDataSlider.HasNext();
335 if (!asrPostProcess.DoPostProcess()) {
336 printf_err("ASR post-processing failed.");
337 return false;
338 }
alexander3c798932021-03-26 21:42:19 +0000339
340 /* Get results. */
341 std::vector<ClassificationResult> asrClassificationResult;
Richard Burton4e002792022-05-04 09:45:02 +0100342 auto& asrClassifier = ctx.Get<AsrClassifier&>("asrClassifier");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000343 asrClassifier.GetClassificationResults(asrOutputTensor,
344 asrClassificationResult,
345 ctx.Get<std::vector<std::string>&>("asrLabels"),
346 1);
alexander3c798932021-03-26 21:42:19 +0000347
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000348 asrResults.emplace_back(
349 asr::AsrResult(asrClassificationResult,
350 (audioDataSlider.Index() * asrAudioParamsSecondsPerSample *
351 asrAudioDataWindowStride),
352 audioDataSlider.Index(),
353 asrScoreThreshold));
alexander3c798932021-03-26 21:42:19 +0000354
355#if VERIFY_TEST_OUTPUT
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000356 armDumpTensor(asrOutputTensor,
357 asrOutputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
alexander3c798932021-03-26 21:42:19 +0000358#endif /* VERIFY_TEST_OUTPUT */
359
360 /* Erase */
361 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100362 hal_lcd_display_text(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000363 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000364 }
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100365 if (!PresentInferenceResult(asrResults)) {
alexander3c798932021-03-26 21:42:19 +0000366 return false;
367 }
368
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100369 profiler.PrintProfilingResult();
370
alexander3c798932021-03-26 21:42:19 +0000371 return true;
372 }
373
Richard Burton4e002792022-05-04 09:45:02 +0100374 /* KWS and ASR inference handler. */
alexander3c798932021-03-26 21:42:19 +0000375 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
376 {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100377 hal_lcd_clear(COLOR_BLACK);
alexander3c798932021-03-26 21:42:19 +0000378
379 /* If the request has a valid size, set the audio index. */
380 if (clipIndex < NUMBER_OF_FILES) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000381 if (!SetAppCtxIfmIdx(ctx, clipIndex, "kws_asr")) {
alexander3c798932021-03-26 21:42:19 +0000382 return false;
383 }
384 }
385
386 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
387
388 do {
389 KWSOutput kwsOutput = doKws(ctx);
390 if (!kwsOutput.executionSuccess) {
Richard Burton4e002792022-05-04 09:45:02 +0100391 printf_err("KWS failed\n");
alexander3c798932021-03-26 21:42:19 +0000392 return false;
393 }
394
395 if (kwsOutput.asrAudioStart != nullptr && kwsOutput.asrAudioSamples > 0) {
Richard Burton4e002792022-05-04 09:45:02 +0100396 info("Trigger keyword spotted\n");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000397 if (!doAsr(ctx, kwsOutput)) {
Richard Burton4e002792022-05-04 09:45:02 +0100398 printf_err("ASR failed\n");
alexander3c798932021-03-26 21:42:19 +0000399 return false;
400 }
401 }
402
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000403 IncrementAppCtxIfmIdx(ctx, "kws_asr");
alexander3c798932021-03-26 21:42:19 +0000404
405 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
406
407 return true;
408 }
409
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100410 static bool PresentInferenceResult(std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000411 {
412 constexpr uint32_t dataPsnTxtStartX1 = 20;
413 constexpr uint32_t dataPsnTxtStartY1 = 30;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000414 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
alexander3c798932021-03-26 21:42:19 +0000415
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100416 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000417
418 /* Display each result. */
419 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
420
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000421 for (auto& result : results) {
alexander3c798932021-03-26 21:42:19 +0000422 std::string topKeyword{"<none>"};
423 float score = 0.f;
424
Richard Burton4e002792022-05-04 09:45:02 +0100425 if (!result.m_resultVec.empty()) {
426 topKeyword = result.m_resultVec[0].m_label;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000427 score = result.m_resultVec[0].m_normalisedVal;
alexander3c798932021-03-26 21:42:19 +0000428 }
429
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000430 std::string resultStr = std::string{"@"} + std::to_string(result.m_timeStamp) +
431 std::string{"s: "} + topKeyword + std::string{" ("} +
432 std::to_string(static_cast<int>(score * 100)) +
433 std::string{"%)"};
alexander3c798932021-03-26 21:42:19 +0000434
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000435 hal_lcd_display_text(
436 resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, 0);
alexander3c798932021-03-26 21:42:19 +0000437 rowIdx1 += dataPsnTxtYIncr;
438
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100439 info("For timestamp: %f (inference #: %" PRIu32 "); threshold: %f\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000440 result.m_timeStamp,
441 result.m_inferenceNumber,
Richard Burton4e002792022-05-04 09:45:02 +0100442 result.m_threshold);
443 for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000444 info("\t\tlabel @ %" PRIu32 ": %s, score: %f\n",
445 j,
Richard Burton4e002792022-05-04 09:45:02 +0100446 result.m_resultVec[j].m_label.c_str(),
447 result.m_resultVec[j].m_normalisedVal);
alexander3c798932021-03-26 21:42:19 +0000448 }
449 }
450
451 return true;
452 }
453
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100454 static bool PresentInferenceResult(std::vector<arm::app::asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000455 {
456 constexpr uint32_t dataPsnTxtStartX1 = 20;
457 constexpr uint32_t dataPsnTxtStartY1 = 80;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000458 constexpr bool allow_multiple_lines = true;
alexander3c798932021-03-26 21:42:19 +0000459
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100460 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000461
462 /* Results from multiple inferences should be combined before processing. */
463 std::vector<arm::app::ClassificationResult> combinedResults;
464 for (auto& result : results) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000465 combinedResults.insert(
466 combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end());
alexander3c798932021-03-26 21:42:19 +0000467 }
468
469 for (auto& result : results) {
470 /* Get the final result string using the decoder. */
471 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
472
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000473 info(
474 "Result for inf %" PRIu32 ": %s\n", result.m_inferenceNumber, infResultStr.c_str());
alexander3c798932021-03-26 21:42:19 +0000475 }
476
477 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
478
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000479 hal_lcd_display_text(finalResultStr.c_str(),
480 finalResultStr.size(),
481 dataPsnTxtStartX1,
482 dataPsnTxtStartY1,
483 allow_multiple_lines);
alexander3c798932021-03-26 21:42:19 +0000484
485 info("Final result: %s\n", finalResultStr.c_str());
486 return true;
487 }
488
} /* namespace app */
} /* namespace arm */