/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "InputFiles.hpp"
#include "AsrClassifier.hpp"
#include "Wav2LetterModel.hpp"
#include "hal.h"
#include "Wav2LetterMfcc.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "AsrResult.hpp"
#include "Wav2LetterPreprocess.hpp"
#include "Wav2LetterPostprocess.hpp"
#include "OutputDecode.hpp"

namespace arm {
namespace app {

    /**
     * @brief       Presents inference results using the data presentation object.
     * @param[in]   platform    Reference to the hal platform object.
     * @param[in]   results     Vector of classification results to be displayed.
     * @return      true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(
                    hal_platform& platform,
                    const std::vector<arm::app::asr::AsrResult>& results);

    /* Audio inference classification handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        auto& platform = ctx.Get<hal_platform&>("platform");
        platform.data_psn->clear(COLOR_BLACK);

        auto& profiler = ctx.Get<Profiler&>("profiler");

        /* If a valid clip index has been requested, set it in the application context. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
                return false;
            }
        }

        /* Get model reference. */
        auto& model = ctx.Get<Model&>("model");
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        /* Get score threshold to be applied for the classifier (post-inference). */
        auto scoreThreshold = ctx.Get<float>("scoreThreshold");

        /* Get tensors. Dimensions of the tensor should have been verified by
         * the callee. */
        TfLiteTensor* inputTensor = model.GetInputTensor(0);
        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
        const uint32_t inputRows = inputTensor->dims->data[arm::app::Wav2LetterModel::ms_inputRowsIdx];

        /* Populate MFCC related parameters. */
        auto mfccParamsWinLen = ctx.Get<uint32_t>("frameLength");
        auto mfccParamsWinStride = ctx.Get<uint32_t>("frameStride");

        /* Populate ASR inference context and inner lengths for input. */
        auto inputCtxLen = ctx.Get<uint32_t>("ctxLen");
        const uint32_t inputInnerLen = inputRows - (2 * inputCtxLen);
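        /* The input rows are treated as ctxLen "context" feature rows on either side of
         * inputInnerLen "inner" rows: only the inner rows carry new audio for a given
         * window, while the context rows overlap with the neighbouring windows. */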

        /* Audio data stride corresponds to inputInnerLen feature vectors. */
        const uint32_t audioParamsWinLen = (inputRows - 1) * mfccParamsWinStride + (mfccParamsWinLen);
        const uint32_t audioParamsWinStride = inputInnerLen * mfccParamsWinStride;
        const float audioParamsSecondsPerSample = (1.0 / audio::Wav2LetterMFCC::ms_defaultSamplingFreq);
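        /* Illustrative example (the figures below are assumptions, not values taken from
         * this file): with a 16 kHz sampling rate, an MFCC frame length of 512 samples,
         * a frame stride of 160 samples, 296 input rows and 98 context rows:
         *   audioParamsWinLen    = (296 - 1) * 160 + 512 = 47712 samples (~2.98 s),
         *   audioParamsWinStride = (296 - 2 * 98) * 160  = 16000 samples (1 s),
         * i.e. each inference consumes ~3 s of audio and advances by 1 s. */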

        /* Get pre/post-processing objects. */
        auto& prep = ctx.Get<audio::asr::Preprocess&>("preprocess");
        auto& postp = ctx.Get<audio::asr::Postprocess&>("postprocess");

        /* Set default reduction axis for post-processing. */
        const uint32_t reductionAxis = arm::app::Wav2LetterModel::ms_outputRowsIdx;
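        /* ms_outputRowsIdx selects the time-step (rows) dimension of the output tensor;
         * this is the axis along which the post-processing is expected to trim the
         * overlapping context sections between consecutive windows. */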

        /* Audio clip start index. */
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        /* Loop to process audio clips. */
        do {
            platform.data_psn->clear(COLOR_BLACK);

            /* Get current audio clip index. */
            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Get the current audio buffer and respective size. */
            const int16_t* audioArr = get_audio_array(currentIndex);
            const uint32_t audioArrSize = get_audio_array_size(currentIndex);

            if (!audioArr) {
                printf_err("Invalid audio array pointer\n");
                return false;
            }

            /* Audio clip must have enough samples to produce 1 MFCC feature. */
            if (audioArrSize < mfccParamsWinLen) {
                printf_err("Not enough audio samples, minimum needed is %" PRIu32 "\n",
                           mfccParamsWinLen);
                return false;
            }

            /* Initialise an audio slider. */
            auto audioDataSlider = audio::FractionalSlidingWindow<const int16_t>(
                    audioArr,
                    audioArrSize,
                    audioParamsWinLen,
                    audioParamsWinStride);
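            /* The sliding window yields chunks of audioParamsWinLen samples, advancing
             * by audioParamsWinStride samples each time; "fractional" here means the
             * final window may start where fewer than audioParamsWinLen samples remain,
             * which is handled inside the loop below. */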

            /* Declare a container for results. */
            std::vector<arm::app::asr::AsrResult> results;

            /* Display message on the LCD - inference running. */
            std::string str_inf{"Running inference... "};
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                 get_filename(currentIndex));

            size_t inferenceWindowLen = audioParamsWinLen;

            /* Start sliding through audio clip. */
            while (audioDataSlider.HasNext()) {

                /* If there is not enough audio left for a full window, work out how
                 * much can still be sent for processing. */
                size_t nextStartIndex = audioDataSlider.NextWindowStartIndex();
                if (nextStartIndex + audioParamsWinLen > audioArrSize) {
                    inferenceWindowLen = audioArrSize - nextStartIndex;
                }
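                /* For the final, shorter window inferenceWindowLen now holds the reduced
                 * sample count; the pre-processing stage is assumed to pad any feature
                 * rows it cannot fill. */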

                const int16_t* inferenceWindow = audioDataSlider.Next();

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                     static_cast<size_t>(ceilf(audioDataSlider.FractionalTotalStrides() + 1)));

                /* Calculate MFCCs, deltas and populate the input tensor. */
                prep.Invoke(inferenceWindow, inferenceWindowLen, inputTensor);

                /* Run inference over this audio clip sliding window. */
                if (!RunInference(model, profiler)) {
                    return false;
                }

                /* Post-process. */
                postp.Invoke(outputTensor, reductionAxis, !audioDataSlider.HasNext());
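                /* The final argument tells the post-processing whether this is the last
                 * window, so it can decide which of the overlapping context rows in the
                 * output to keep and which to drop. */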

                /* Get results. */
                std::vector<ClassificationResult> classificationResult;
                auto& classifier = ctx.Get<AsrClassifier&>("classifier");
                classifier.GetClassificationResults(
                        outputTensor, classificationResult,
                        ctx.Get<std::vector<std::string>&>("labels"), 1);

                results.emplace_back(asr::AsrResult(classificationResult,
                                                    (audioDataSlider.Index() *
                                                     audioParamsSecondsPerSample *
                                                     audioParamsWinStride),
                                                    audioDataSlider.Index(), scoreThreshold));
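                /* Note: the time stamp passed above is the window start time in seconds,
                 * i.e. window index * stride in samples * seconds per sample. */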

#if VERIFY_TEST_OUTPUT
                arm::app::DumpTensor(outputTensor,
                    outputTensor->dims->data[arm::app::Wav2LetterModel::ms_outputColsIdx]);
#endif /* VERIFY_TEST_OUTPUT */

            }

            /* Erase the inference message by overwriting it with spaces. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                    str_inf.c_str(), str_inf.size(),
                    dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);

            ctx.Set<std::vector<arm::app::asr::AsrResult>>("results", results);

            if (!PresentInferenceResult(platform, results)) {
                return false;
            }

            profiler.PrintProfilingResult();

            IncrementAppCtxIfmIdx(ctx, "clipIndex");
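            /* IncrementAppCtxIfmIdx is assumed to wrap the index back to 0 after the
             * last file, so with runAll set the loop condition below terminates once
             * the index returns to startClipIdx. */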

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       const std::vector<arm::app::asr::AsrResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 60;
        constexpr bool allow_multiple_lines = true;

        platform.data_psn->set_text_color(COLOR_GREEN);

        info("Final results:\n");
        info("Total number of inferences: %zu\n", results.size());
        /* Results from multiple inferences should be combined before processing. */
        std::vector<arm::app::ClassificationResult> combinedResults;
        for (auto& result : results) {
            combinedResults.insert(combinedResults.end(),
                                   result.m_resultVec.begin(),
                                   result.m_resultVec.end());
        }
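        /* Concatenating the per-window results lets the decoder below produce a single
         * transcript for the whole clip, in addition to the per-inference strings
         * printed next. */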

        /* Get each inference result string using the decoder. */
        for (const auto& result : results) {
            std::string infResultStr = audio::asr::DecodeOutput(result.m_resultVec);

            info("For timestamp: %f (inference #: %" PRIu32 "); label: %s\n",
                 result.m_timeStamp, result.m_inferenceNumber,
                 infResultStr.c_str());
        }

        /* Get the decoded result for the combined results. */
        std::string finalResultStr = audio::asr::DecodeOutput(combinedResults);

        platform.data_psn->present_data_text(
                finalResultStr.c_str(), finalResultStr.size(),
                dataPsnTxtStartX1, dataPsnTxtStartY1,
                allow_multiple_lines);

        info("Complete recognition: %s\n", finalResultStr.c_str());
        return true;
    }

} /* namespace app */
} /* namespace arm */