blob: d13a03a34965c9c8fd379075dc9f0d4f1c021d1a [file] [log] [blame]
/*
 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
 * <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "UseCaseHandler.hpp"
18
alexander3c798932021-03-26 21:42:19 +000019#include "AsrClassifier.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000020#include "AsrResult.hpp"
alexander3c798932021-03-26 21:42:19 +000021#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000022#include "ImageUtils.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000023#include "InputFiles.hpp"
alexander3c798932021-03-26 21:42:19 +000024#include "OutputDecode.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000025#include "UseCaseCommonUtils.hpp"
26#include "Wav2LetterModel.hpp"
27#include "Wav2LetterPostprocess.hpp"
28#include "Wav2LetterPreprocess.hpp"
29#include "hal.h"
alexander31ae9f02022-02-10 16:15:54 +000030#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000031
32namespace arm {
33namespace app {
34
35 /**
Richard Burtonb40ecf82022-04-22 16:14:57 +010036 * @brief Presents ASR inference results.
37 * @param[in] results Vector of ASR classification results to be displayed.
38 * @return true if successful, false otherwise.
alexander3c798932021-03-26 21:42:19 +000039 **/
Richard Burtonc2911442022-04-22 09:08:21 +010040 static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results);
alexander3c798932021-03-26 21:42:19 +000041
Richard Burtonc2911442022-04-22 09:08:21 +010042 /* ASR inference handler. */
alexander3c798932021-03-26 21:42:19 +000043 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
44 {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000045 auto& model = ctx.Get<Model&>("model");
46 auto& profiler = ctx.Get<Profiler&>("profiler");
47 auto mfccFrameLen = ctx.Get<uint32_t>("frameLength");
Richard Burtonc2911442022-04-22 09:08:21 +010048 auto mfccFrameStride = ctx.Get<uint32_t>("frameStride");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000049 auto scoreThreshold = ctx.Get<float>("scoreThreshold");
50 auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
alexander3c798932021-03-26 21:42:19 +000051 /* If the request has a valid size, set the audio index. */
52 if (clipIndex < NUMBER_OF_FILES) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000053 if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
alexander3c798932021-03-26 21:42:19 +000054 return false;
55 }
56 }
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000057 auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
Richard Burtonc2911442022-04-22 09:08:21 +010058 constexpr uint32_t dataPsnTxtInfStartX = 20;
59 constexpr uint32_t dataPsnTxtInfStartY = 40;
alexander3c798932021-03-26 21:42:19 +000060
alexander3c798932021-03-26 21:42:19 +000061 if (!model.IsInited()) {
62 printf_err("Model is not initialised! Terminating processing.\n");
63 return false;
64 }
65
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000066 TfLiteTensor* inputTensor = model.GetInputTensor(0);
Richard Burtonb40ecf82022-04-22 16:14:57 +010067 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
68
Richard Burtonc2911442022-04-22 09:08:21 +010069 /* Get input shape. Dimensions of the tensor should have been verified by
alexander3c798932021-03-26 21:42:19 +000070 * the callee. */
Richard Burtonc2911442022-04-22 09:08:21 +010071 TfLiteIntArray* inputShape = model.GetInputShape(0);
alexander3c798932021-03-26 21:42:19 +000072
Richard Burtonc2911442022-04-22 09:08:21 +010073 const uint32_t inputRowsSize = inputShape->data[Wav2LetterModel::ms_inputRowsIdx];
74 const uint32_t inputInnerLen = inputRowsSize - (2 * inputCtxLen);
alexander3c798932021-03-26 21:42:19 +000075
76 /* Audio data stride corresponds to inputInnerLen feature vectors. */
Richard Burtonc2911442022-04-22 09:08:21 +010077 const uint32_t audioDataWindowLen = (inputRowsSize - 1) * mfccFrameStride + (mfccFrameLen);
78 const uint32_t audioDataWindowStride = inputInnerLen * mfccFrameStride;
alexander3c798932021-03-26 21:42:19 +000079
Richard Burtonc2911442022-04-22 09:08:21 +010080 /* NOTE: This is only used for time stamp calculation. */
81 const float secondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
alexander3c798932021-03-26 21:42:19 +000082
Richard Burtonc2911442022-04-22 09:08:21 +010083 /* Set up pre and post-processing objects. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000084 AsrPreProcess preProcess = AsrPreProcess(inputTensor,
85 Wav2LetterModel::ms_numMfccFeatures,
Richard Burtonb40ecf82022-04-22 16:14:57 +010086 inputShape->data[Wav2LetterModel::ms_inputRowsIdx],
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000087 mfccFrameLen,
88 mfccFrameStride);
alexander3c798932021-03-26 21:42:19 +000089
Richard Burtonc2911442022-04-22 09:08:21 +010090 std::vector<ClassificationResult> singleInfResult;
Richard Burtonb40ecf82022-04-22 16:14:57 +010091 const uint32_t outputCtxLen = AsrPostProcess::GetOutputContextLen(model, inputCtxLen);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000092 AsrPostProcess postProcess = AsrPostProcess(outputTensor,
93 ctx.Get<AsrClassifier&>("classifier"),
94 ctx.Get<std::vector<std::string>&>("labels"),
95 singleInfResult,
96 outputCtxLen,
97 Wav2LetterModel::ms_blankTokenIdx,
98 Wav2LetterModel::ms_outputRowsIdx);
Richard Burtonc2911442022-04-22 09:08:21 +010099
alexander3c798932021-03-26 21:42:19 +0000100 /* Loop to process audio clips. */
101 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100102 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +0000103
alexander3c798932021-03-26 21:42:19 +0000104 /* Get current audio clip index. */
105 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
106
107 /* Get the current audio buffer and respective size. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000108 const int16_t* audioArr = GetAudioArray(currentIndex);
109 const uint32_t audioArrSize = GetAudioArraySize(currentIndex);
alexander3c798932021-03-26 21:42:19 +0000110
111 if (!audioArr) {
Richard Burtonc2911442022-04-22 09:08:21 +0100112 printf_err("Invalid audio array pointer.\n");
alexander3c798932021-03-26 21:42:19 +0000113 return false;
114 }
115
Richard Burtonc2911442022-04-22 09:08:21 +0100116 /* Audio clip needs enough samples to produce at least 1 MFCC feature. */
117 if (audioArrSize < mfccFrameLen) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100118 printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
Richard Burtonc2911442022-04-22 09:08:21 +0100119 mfccFrameLen);
alexander3c798932021-03-26 21:42:19 +0000120 return false;
121 }
122
Richard Burtonc2911442022-04-22 09:08:21 +0100123 /* Creating a sliding window through the whole audio clip. */
alexander80eecfb2021-07-06 19:47:59 +0100124 auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000125 audioArr, audioArrSize, audioDataWindowLen, audioDataWindowStride);
alexander3c798932021-03-26 21:42:19 +0000126
Richard Burtonc2911442022-04-22 09:08:21 +0100127 /* Declare a container for final results. */
128 std::vector<asr::AsrResult> finalResults;
alexander3c798932021-03-26 21:42:19 +0000129
130 /* Display message on the LCD - inference running. */
131 std::string str_inf{"Running inference... "};
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000132 hal_lcd_display_text(
133 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
alexander3c798932021-03-26 21:42:19 +0000134
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000135 info("Running inference on audio clip %" PRIu32 " => %s\n",
136 currentIndex,
137 GetFilename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000138
Richard Burtonc2911442022-04-22 09:08:21 +0100139 size_t inferenceWindowLen = audioDataWindowLen;
alexander3c798932021-03-26 21:42:19 +0000140
141 /* Start sliding through audio clip. */
142 while (audioDataSlider.HasNext()) {
143
Richard Burtonc2911442022-04-22 09:08:21 +0100144 /* If not enough audio, see how much can be sent for processing. */
alexander3c798932021-03-26 21:42:19 +0000145 size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
Richard Burtonc2911442022-04-22 09:08:21 +0100146 if (nextStartIndex + audioDataWindowLen > audioArrSize) {
alexander3c798932021-03-26 21:42:19 +0000147 inferenceWindowLen = audioArrSize - nextStartIndex;
148 }
149
150 const int16_t* inferenceWindow = audioDataSlider.Next();
151
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000152 info("Inference %zu/%zu\n",
153 audioDataSlider.Index() + 1,
alexander3c798932021-03-26 21:42:19 +0000154 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
155
Richard Burtonc2911442022-04-22 09:08:21 +0100156 /* Run the pre-processing, inference and post-processing. */
Richard Burtonb40ecf82022-04-22 16:14:57 +0100157 if (!preProcess.DoPreProcess(inferenceWindow, inferenceWindowLen)) {
158 printf_err("Pre-processing failed.");
Richard Burtonc2911442022-04-22 09:08:21 +0100159 return false;
160 }
Richard Burtonc2911442022-04-22 09:08:21 +0100161
Richard Burtonb40ecf82022-04-22 16:14:57 +0100162 if (!RunInference(model, profiler)) {
163 printf_err("Inference failed.");
164 return false;
165 }
166
167 /* Post processing needs to know if we are on the last audio window. */
Richard Burtonc2911442022-04-22 09:08:21 +0100168 postProcess.m_lastIteration = !audioDataSlider.HasNext();
Richard Burtonb40ecf82022-04-22 16:14:57 +0100169 if (!postProcess.DoPostProcess()) {
170 printf_err("Post-processing failed.");
alexander27b62d92021-05-04 20:46:08 +0100171 return false;
172 }
alexander3c798932021-03-26 21:42:19 +0000173
Richard Burtonc2911442022-04-22 09:08:21 +0100174 /* Add results from this window to our final results vector. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000175 finalResults.emplace_back(asr::AsrResult(
176 singleInfResult,
177 (audioDataSlider.Index() * secondsPerSample * audioDataWindowStride),
178 audioDataSlider.Index(),
179 scoreThreshold));
alexander3c798932021-03-26 21:42:19 +0000180
181#if VERIFY_TEST_OUTPUT
Richard Burtonc2911442022-04-22 09:08:21 +0100182 armDumpTensor(outputTensor,
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000183 outputTensor->dims->data[Wav2LetterModel::ms_outputColsIdx]);
184#endif /* VERIFY_TEST_OUTPUT */
Richard Burtonc2911442022-04-22 09:08:21 +0100185 } /* while (audioDataSlider.HasNext()) */
alexander3c798932021-03-26 21:42:19 +0000186
187 /* Erase. */
188 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000189 hal_lcd_display_text(
190 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
alexander3c798932021-03-26 21:42:19 +0000191
Richard Burtonc2911442022-04-22 09:08:21 +0100192 ctx.Set<std::vector<asr::AsrResult>>("results", finalResults);
alexander3c798932021-03-26 21:42:19 +0000193
Richard Burtonc2911442022-04-22 09:08:21 +0100194 if (!PresentInferenceResult(finalResults)) {
alexander3c798932021-03-26 21:42:19 +0000195 return false;
196 }
197
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100198 profiler.PrintProfilingResult();
199
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000200 IncrementAppCtxIfmIdx(ctx, "clipIndex");
alexander3c798932021-03-26 21:42:19 +0000201
Richard Burtonc2911442022-04-22 09:08:21 +0100202 } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
alexander3c798932021-03-26 21:42:19 +0000203
204 return true;
205 }
206
Richard Burtonc2911442022-04-22 09:08:21 +0100207 static bool PresentInferenceResult(const std::vector<asr::AsrResult>& results)
alexander3c798932021-03-26 21:42:19 +0000208 {
209 constexpr uint32_t dataPsnTxtStartX1 = 20;
210 constexpr uint32_t dataPsnTxtStartY1 = 60;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000211 constexpr bool allow_multiple_lines = true;
alexander3c798932021-03-26 21:42:19 +0000212
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100213 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000214
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100215 info("Final results:\n");
216 info("Total number of inferences: %zu\n", results.size());
alexander3c798932021-03-26 21:42:19 +0000217 /* Results from multiple inferences should be combined before processing. */
Richard Burtonc2911442022-04-22 09:08:21 +0100218 std::vector<ClassificationResult> combinedResults;
219 for (const auto& result : results) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000220 combinedResults.insert(
221 combinedResults.end(), result.m_resultVec.begin(), result.m_resultVec.end());
alexander3c798932021-03-26 21:42:19 +0000222 }
223
224 /* Get each inference result string using the decoder. */
Richard Burtonc2911442022-04-22 09:08:21 +0100225 for (const auto& result : results) {
alexander3c798932021-03-26 21:42:19 +0000226 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
227
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100228 info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000229 result.m_timeStamp,
230 result.m_inferenceNumber,
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100231 infResultStr.c_str());
alexander3c798932021-03-26 21:42:19 +0000232 }
233
234 /* Get the decoded result for the combined result. */
235 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
236
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000237 hal_lcd_display_text(finalResultStr.c_str(),
238 finalResultStr.size(),
239 dataPsnTxtStartX1,
240 dataPsnTxtStartY1,
241 allow_multiple_lines);
alexander3c798932021-03-26 21:42:19 +0000242
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100243 info("Complete recognition: %s\n", finalResultStr.c_str());
alexander3c798932021-03-26 21:42:19 +0000244 return true;
245 }
246
247} /* namespace app */
248} /* namespace arm */