/*
 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "InputFiles.hpp"
#include "AsrClassifier.hpp"
#include "Wav2LetterModel.hpp"
#include "hal.h"
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "ImageUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "AsrResult.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "OutputDecode.hpp"
#include "log_macros.h"

namespace arm {
namespace app {

    /**
     * @brief       Presents ASR inference results on the LCD and via logging.
     * @param[in]   results     Vector of classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results);

    /* Audio inference classification handler. */
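    /* Note: the application context is expected to provide the "model",
     * "profiler", "classifier", "labels", "preprocess", "postprocess",
     * "scoreThreshold", "frameLength", "frameStride", "ctxLen" and
     * "clipIndex" entries read via ctx.Get below; the "results" entry is
     * written back after each clip. */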
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        hal_lcd_clear(COLOR_BLACK);

        auto& profiler = ctx.Get<Profiler&>("profiler");

alexander3c798932021-03-26 21:42:19 +000054 /* If the request has a valid size, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
                return false;
            }
        }

        /* Get model reference. */
        auto& model = ctx.Get<Model&>("model");
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto scoreThreshold = ctx.Get<float>("scoreThreshold");

        /* Get tensors. Dimensions of the tensor should have been verified by
         * the callee. */
        TfLiteTensor* inputTensor = model.GetInputTensor(0);
        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
        const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate MFCC related parameters. */
        auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
        auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");

        /* Populate ASR inference context and inner lengths for input. */
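        /* The model input is treated as [left context | inner | right context]
         * rows of feature vectors; only the inner rows advance between
         * successive windows. */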
        auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
        const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
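        /* Illustrative numbers only (the actual values come from the use case
         * configuration): assuming a 16 kHz model with frameLength = 512,
         * frameStride = 160, inputRows = 296 and ctxLen = 98, each window
         * covers (296 - 1) * 160 + 512 = 47712 samples (~2.98 s) and the
         * slider advances by (296 - 2 * 98) * 160 = 16000 samples (1 s). */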
        const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);

        /* Get pre/post-processing objects. */
        auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;

        /* Audio clip start index. */
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        /* Loop to process audio clips. */
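        /* When runAll is set, the clip index is expected to wrap around (see
         * IncrementAppCtxIfmIdx), so the loop visits every clip once and stops
         * when it returns to the starting index. */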
        do {
            hal_lcd_clear(COLOR_BLACK);

            /* Get current audio clip index. */
            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Get the current audio buffer and respective size. */
            const int16_t* audioArr = get_audio_array(currentIndex);
            const uint32_t audioArrSize = get_audio_array_size(currentIndex);

            if (!audioArr) {
                printf_err("Invalid audio array pointer\n");
                return false;
            }

            /* Audio clip must have enough samples to produce 1 MFCC feature. */
            if (audioArrSize < mfccParamsWinLen) {
                printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                           mfccParamsWinLen);
                return false;
            }

            /* Initialise an audio slider. */
            auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                    audioArr,
                    audioArrSize,
                    audioParamsWinLen,
                    audioParamsWinStride);

            /* Declare a container for results. */
            std::vector<arm::app::asr::AsrResult> results;

            /* Display message on the LCD - inference running. */
            std::string str_inf{"Running inference... "};
            hal_lcd_display_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                 get_filename(currentIndex));

            size_t inferenceWindowLen = audioParamsWinLen;

            /* Start sliding through audio clip. */
            while (audioDataSlider.HasNext()) {

                /* If there is not enough audio left, work out how much can be sent for processing. */
                size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
                if (nextStartIndex + audioParamsWinLen > audioArrSize) {
                    inferenceWindowLen = audioArrSize - nextStartIndex;
                }

                const int16_t* inferenceWindow = audioDataSlider.Next();

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                     static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

                /* Calculate MFCCs, deltas and populate the input tensor. */
                prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);

                /* Run inference over this audio clip sliding window. */
                if (!RunInference(model, profiler)) {
                    return false;
                }

                /* Post-process. */
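                /* The last argument evaluates to true only for the final window
                 * of the clip, letting post-processing know this is the last
                 * iteration. */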
                postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());

                /* Get results. */
                std::vector<ClassificationResult> classificationResult;
                auto& classifier = ctx.Get<AsrClassifier&>("classifier");
                classifier.GetClassificationResults(
                        outputTensor, classificationResult,
                        ctx.Get<std::vector<std::string>&>("labels"), 1);

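                /* The result timestamp is the start time of this window within
                 * the clip, in seconds: window index * stride in samples *
                 * seconds per sample. */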
                results.emplace_back(asr::AsrResult(classificationResult,
                                                    (audioDataSlider.Index() *
                                                     audioParamsSecondsPerSample *
                                                     audioParamsWinStride),
                                                    audioDataSlider.Index(), scoreThreshold));

#if VERIFY_TEST_OUTPUT
                arm::app::DumpTensor(outputTensor,
                    outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            }

            /* Erase the inference message by overwriting it with spaces. */
            str_inf = std::string(str_inf.size(), ' ');
            hal_lcd_display_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

            ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);

            if (!PresentInferenceResult(results)) {
                return false;
            }

            profiler.PrintProfilingResult();

            IncrementAppCtxIfmIdx(ctx, "clipIndex");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }


    static bool PresentInferenceResult(const std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 60;
        constexpr bool allow_multiple_lines = true;

        hal_lcd_set_text_color(COLOR_GREEN);

        info("Final results:\n");
        info("Total number of inferences: %zu\n", results.size());
        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }

        /* Get each inference result string using the decoder. */
        for (const auto& result : results) {
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
                 result.m_timeStamp, result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        /* Get the decoded result for the combined result. */
        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        hal_lcd_display_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1,
                allow_multiple_lines);

        info("Complete recognition: %s\n", finalResultStr.c_str());
        return true;
    }

} /* namespace app */
} /* namespace arm */