/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "AsrClassifier.hpp"
#include "AsrResult.hpp"
#include "AudioUtils.hpp"
#include "InputFiles.hpp"
#include "OutputDecode.hpp"
#include "UseCaseCommonUtils.hpp"
#include "Wav2LetterMfcc.hpp"
#include "Wav2LetterModel.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "hal.h"

namespace arm {
namespace app {

34 /**
35 * @brief Helper function to increment current audio clip index.
36 * @param[in,out] ctx Pointer to the application context object.
37 **/
38 static void _IncrementAppCtxClipIdx(ApplicationContext& ctx);
39
40 /**
41 * @brief Helper function to set the audio clip index.
42 * @param[in,out] ctx Pointer to the application context object.
43 * @param[in] idx Value to be set.
44 * @return true if index is set, false otherwise.
45 **/
46 static bool _SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);
47
48 /**
49 * @brief Presents inference results using the data presentation
50 * object.
51 * @param[in] platform Reference to the hal platform object.
52 * @param[in] results Vector of classification results to be displayed.
53 * @param[in] infTimeMs Inference time in milliseconds, if available
54 * otherwise, this can be passed in as 0.
55 * @return true if successful, false otherwise.
56 **/
57 static bool _PresentInferenceResult(
58 hal_platform& platform,
59 const std::vector<arm::app::asr::AsrResult>& results);
60
61 /* Audio inference classification handler. */
62 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
63 {
64 constexpr uint32_t dataPsnTxtInfStartX = 20;
65 constexpr uint32_t dataPsnTxtInfStartY = 40;
66
67 auto& platform = ctx.Get<hal_platform&>("platform");
68 platform.data_psn->clear(COLOR_BLACK);
69
Isabella Gottardi8df12f32021-04-07 17:15:31 +010070 auto& profiler = ctx.Get<Profiler&>("profiler");
71
alexander3c798932021-03-26 21:42:19 +000072 /* If the request has a valid size, set the audio index. */
73 if (clipIndex < NUMBER_OF_FILES) {
74 if (!_SetAppCtxClipIdx(ctx, clipIndex)) {
75 return false;
76 }
77 }
78
79 /* Get model reference. */
80 auto& model = ctx.Get<Model&>("model");
81 if (!model.IsInited()) {
82 printf_err("Model is not initialised! Terminating processing.\n");
83 return false;
84 }
85
86 /* Get score threshold to be applied for the classifier (post-inference). */
87 auto scoreThreshold = ctx.Get<float>("scoreThreshold");
88
89 /* Get tensors. Dimensions of the tensor should have been verified by
90 * the callee. */
91 TfLiteTensor* inputTensor = model.GetInputTensor(0);
92 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
93 const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];
94
95 /* Populate MFCC related parameters. */
96 auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
97 auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");
98
99 /* Populate ASR inference context and inner lengths for input. */
100 auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
101 const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
102
103 /* Audio data stride corresponds to inputInnerLen feature vectors. */
104 const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
105 const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
106 const float audioParamsSecondsPerSample = (1.0/audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
107
108 /* Get pre/post-processing objects. */
109 auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
110 auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");
111
112 /* Set default reduction axis for post-processing. */
113 const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
114
115 /* Audio clip start index. */
116 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
117
118 /* Loop to process audio clips. */
119 do {
120 /* Get current audio clip index. */
121 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
122
123 /* Get the current audio buffer and respective size. */
124 const int16_t* audioArr = get_audio_array(currentIndex);
125 const uint32_t audioArrSize = get_audio_array_size(currentIndex);
126
127 if (!audioArr) {
128 printf_err("Invalid audio array pointer\n");
129 return false;
130 }
131
132 /* Audio clip must have enough samples to produce 1 MFCC feature. */
133 if (audioArrSize < mfccParamsWinLen) {
134 printf_err("Not enough audio samples, minimum needed is %u\n", mfccParamsWinLen);
135 return false;
136 }
137
138 /* Initialise an audio slider. */
139 auto audioDataSlider = audio::ASRSlidingWindow<const int16_t>(
140 audioArr,
141 audioArrSize,
142 audioParamsWinLen,
143 audioParamsWinStride);
144
145 /* Declare a container for results. */
146 std::vector<arm::app::asr::AsrResult> results;
147
148 /* Display message on the LCD - inference running. */
149 std::string str_inf{"Running inference... "};
150 platform.data_psn->present_data_text(
151 str_inf.c_str(), str_inf.size(),
152 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
153
154 info("Running inference on audio clip %u => %s\n", currentIndex,
155 get_filename(currentIndex));
156
157 size_t inferenceWindowLen = audioParamsWinLen;
158
159 /* Start sliding through audio clip. */
160 while (audioDataSlider.HasNext()) {
161
162 /* If not enough audio see how much can be sent for processing. */
163 size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
164 if (nextStartIndex + audioParamsWinLen > audioArrSize) {
165 inferenceWindowLen = audioArrSize - nextStartIndex;
166 }
167
168 const int16_t* inferenceWindow = audioDataSlider.Next();
169
170 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
171 static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));
172
alexander3c798932021-03-26 21:42:19 +0000173 /* Calculate MFCCs, deltas and populate the input tensor. */
174 prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);
175
alexander3c798932021-03-26 21:42:19 +0000176 /* Run inference over this audio clip sliding window. */
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100177 arm::app::RunInference(model, profiler);
alexander3c798932021-03-26 21:42:19 +0000178
179 /* Post-process. */
180 postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
181
182 /* Get results. */
183 std::vector<ClassificationResult> classificationResult;
184 auto& classifier = ctx.Get<AsrClassifier&>("classifier");
185 classifier.GetClassificationResults(
186 outputTensor, classificationResult,
187 ctx.Get<std::vector<std::string>&>("labels"), 1);
188
189 results.emplace_back(asr::AsrResult(classificationResult,
190 (audioDataSlider.Index() *
191 audioParamsSecondsPerSample *
192 audioParamsWinStride),
193 audioDataSlider.Index(), scoreThreshold));
194
195#if VERIFY_TEST_OUTPUT
196 arm::app::DumpTensor(outputTensor,
197 outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
198#endif /* VERIFY_TEST_OUTPUT */
199
200 }
201
202 /* Erase. */
203 str_inf = std::string(str_inf.size(), ' ');
204 platform.data_psn->present_data_text(
205 str_inf.c_str(), str_inf.size(),
206 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
207
208 ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);
209
210 if (!_PresentInferenceResult(platform, results)) {
211 return false;
212 }
213
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100214 profiler.PrintProfilingResult();
215
alexander3c798932021-03-26 21:42:19 +0000216 _IncrementAppCtxClipIdx(ctx);
217
218 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
219
220 return true;
221 }
222
223 static void _IncrementAppCtxClipIdx(ApplicationContext& ctx)
224 {
225 auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");
226
227 if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
228 ctx.Set<uint32_t>("clipIndex", 0);
229 return;
230 }
231 ++curAudioIdx;
232 ctx.Set<uint32_t>("clipIndex", curAudioIdx);
233 }
234
235 static bool _SetAppCtxClipIdx(ApplicationContext& ctx, const uint32_t idx)
236 {
237 if (idx >= NUMBER_OF_FILES) {
238 printf_err("Invalid idx %u (expected less than %u)\n",
239 idx, NUMBER_OF_FILES);
240 return false;
241 }
242
243 ctx.Set<uint32_t>("clipIndex", idx);
244 return true;
245 }
246
247 static bool _PresentInferenceResult(hal_platform& platform,
248 const std::vector<arm::app::asr::AsrResult>& results)
249 {
250 constexpr uint32_t dataPsnTxtStartX1 = 20;
251 constexpr uint32_t dataPsnTxtStartY1 = 60;
252 constexpr bool allow_multiple_lines = true;
253
254 platform.data_psn->set_text_color(COLOR_GREEN);
255
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100256 info("Final results:\n");
257 info("Total number of inferences: %zu\n", results.size());
alexander3c798932021-03-26 21:42:19 +0000258 /* Results from multiple inferences should be combined before processing. */
259 std::vector<arm::app::ClassificationResult> combinedResults;
260 for (auto& result : results) {
261 combinedResults.insert(combinedResults.end(),
262 result.m_resultVec.begin(),
263 result.m_resultVec.end());
264 }
265
266 /* Get each inference result string using the decoder. */
267 for (const auto & result : results) {
268 std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);
269
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100270 info("For timestamp: %f (inference #: %u); label: %s\n",
271 result.m_timeStamp, result.m_inferenceNumber,
272 infResultStr.c_str());
alexander3c798932021-03-26 21:42:19 +0000273 }
274
275 /* Get the decoded result for the combined result. */
276 std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);
277
278 platform.data_psn->present_data_text(
279 finalResultStr.c_str(), finalResultStr.size(),
280 dataPsnTxtStartX1, dataPsnTxtStartY1,
281 allow_multiple_lines);
282
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100283 info("Complete recognition: %s\n", finalResultStr.c_str());
alexander3c798932021-03-26 21:42:19 +0000284 return true;
285 }

} /* namespace app */
} /* namespace arm */