blob: ce99ed373e6f9ce6ebd00c51e31aa6b6f0ae7041 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Kshitij Sisodia2ea46232022-12-19 16:37:33 +00002 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
3 * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
alexander3c798932021-03-26 21:42:19 +00004 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
alexander3c798932021-03-26 21:42:19 +000019#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000020#include "ImageUtils.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000021#include "InputFiles.hpp"
22#include "KwsClassifier.hpp"
Richard Burtone6398cd2022-04-13 11:58:28 +010023#include "KwsProcessing.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000024#include "KwsResult.hpp"
25#include "MicroNetKwsModel.hpp"
26#include "UseCaseCommonUtils.hpp"
27#include "hal.h"
28#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000029
30#include <vector>
alexander3c798932021-03-26 21:42:19 +000031
alexander3c798932021-03-26 21:42:19 +000032namespace arm {
33namespace app {
34
alexander3c798932021-03-26 21:42:19 +000035 /**
Richard Burtone6398cd2022-04-13 11:58:28 +010036 * @brief Presents KWS inference results.
37 * @param[in] results Vector of KWS classification results to be displayed.
alexander3c798932021-03-26 21:42:19 +000038 * @return true if successful, false otherwise.
39 **/
Richard Burtonb40ecf82022-04-22 16:14:57 +010040 static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results);
alexander3c798932021-03-26 21:42:19 +000041
Richard Burtone6398cd2022-04-13 11:58:28 +010042 /* KWS inference handler. */
alexander3c798932021-03-26 21:42:19 +000043 bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
44 {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000045 auto& profiler = ctx.Get<Profiler&>("profiler");
46 auto& model = ctx.Get<Model&>("model");
Richard Burtone6398cd2022-04-13 11:58:28 +010047 const auto mfccFrameLength = ctx.Get<int>("frameLength");
48 const auto mfccFrameStride = ctx.Get<int>("frameStride");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000049 const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
Richard Burtonb40ecf82022-04-22 16:14:57 +010050
Richard Burtone6398cd2022-04-13 11:58:28 +010051 /* If the request has a valid size, set the audio index. */
52 if (clipIndex < NUMBER_OF_FILES) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000053 if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
Richard Burtone6398cd2022-04-13 11:58:28 +010054 return false;
55 }
56 }
57 auto initialClipIdx = ctx.Get<uint32_t>("clipIndex");
alexander3c798932021-03-26 21:42:19 +000058
59 constexpr uint32_t dataPsnTxtInfStartX = 20;
60 constexpr uint32_t dataPsnTxtInfStartY = 40;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000061 constexpr int minTensorDims =
62 static_cast<int>((MicroNetKwsModel::ms_inputRowsIdx > MicroNetKwsModel::ms_inputColsIdx)
63 ? MicroNetKwsModel::ms_inputRowsIdx
64 : MicroNetKwsModel::ms_inputColsIdx);
alexander3c798932021-03-26 21:42:19 +000065
alexander3c798932021-03-26 21:42:19 +000066 if (!model.IsInited()) {
67 printf_err("Model is not initialised! Terminating processing.\n");
68 return false;
69 }
70
Richard Burtonb40ecf82022-04-22 16:14:57 +010071 /* Get Input and Output tensors for pre/post processing. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000072 TfLiteTensor* inputTensor = model.GetInputTensor(0);
Richard Burtonb40ecf82022-04-22 16:14:57 +010073 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
alexander3c798932021-03-26 21:42:19 +000074 if (!inputTensor->dims) {
75 printf_err("Invalid input tensor dims\n");
76 return false;
77 } else if (inputTensor->dims->size < minTensorDims) {
78 printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
79 return false;
80 }
81
Richard Burtone6398cd2022-04-13 11:58:28 +010082 /* Get input shape for feature extraction. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000083 TfLiteIntArray* inputShape = model.GetInputShape(0);
Richard Burtonb40ecf82022-04-22 16:14:57 +010084 const uint32_t numMfccFeatures = inputShape->data[MicroNetKwsModel::ms_inputColsIdx];
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000085 const uint32_t numMfccFrames =
86 inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];
alexander3c798932021-03-26 21:42:19 +000087
88 /* We expect to be sampling 1 second worth of data at a time.
89 * NOTE: This is only used for time stamp calculation. */
Richard Burtone6398cd2022-04-13 11:58:28 +010090 const float secondsPerSample = 1.0 / audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;
91
92 /* Set up pre and post-processing. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000093 KwsPreProcess preProcess = KwsPreProcess(
94 inputTensor, numMfccFeatures, numMfccFrames, mfccFrameLength, mfccFrameStride);
Richard Burtone6398cd2022-04-13 11:58:28 +010095
96 std::vector<ClassificationResult> singleInfResult;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000097 KwsPostProcess postProcess = KwsPostProcess(outputTensor,
98 ctx.Get<KwsClassifier&>("classifier"),
Richard Burtone6398cd2022-04-13 11:58:28 +010099 ctx.Get<std::vector<std::string>&>("labels"),
Richard Burtonc2911442022-04-22 09:08:21 +0100100 singleInfResult);
Richard Burtone6398cd2022-04-13 11:58:28 +0100101
Richard Burtonb40ecf82022-04-22 16:14:57 +0100102 /* Loop to process audio clips. */
alexander3c798932021-03-26 21:42:19 +0000103 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100104 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +0000105
alexander3c798932021-03-26 21:42:19 +0000106 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
107
alexander3c798932021-03-26 21:42:19 +0000108 /* Creating a sliding window through the whole audio clip. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000109 auto audioDataSlider =
110 audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex),
111 GetAudioArraySize(currentIndex),
112 preProcess.m_audioDataWindowSize,
113 preProcess.m_audioDataStride);
alexander3c798932021-03-26 21:42:19 +0000114
Richard Burtone6398cd2022-04-13 11:58:28 +0100115 /* Declare a container to hold results from across the whole audio clip. */
116 std::vector<kws::KwsResult> finalResults;
alexander3c798932021-03-26 21:42:19 +0000117
118 /* Display message on the LCD - inference running. */
119 std::string str_inf{"Running inference... "};
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000120 hal_lcd_display_text(
121 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
122 info("Running inference on audio clip %" PRIu32 " => %s\n",
123 currentIndex,
124 GetFilename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000125
126 /* Start sliding through audio clip. */
127 while (audioDataSlider.HasNext()) {
Richard Burtone6398cd2022-04-13 11:58:28 +0100128 const int16_t* inferenceWindow = audioDataSlider.Next();
alexander3c798932021-03-26 21:42:19 +0000129
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000130 info("Inference %zu/%zu\n",
131 audioDataSlider.Index() + 1,
alexander3c798932021-03-26 21:42:19 +0000132 audioDataSlider.TotalStrides() + 1);
133
Richard Burtone6398cd2022-04-13 11:58:28 +0100134 /* Run the pre-processing, inference and post-processing. */
Richard Burtonec5e99b2022-10-05 11:00:37 +0100135 if (!preProcess.DoPreProcess(inferenceWindow, audioDataSlider.Index())) {
Richard Burtonb40ecf82022-04-22 16:14:57 +0100136 printf_err("Pre-processing failed.");
alexander27b62d92021-05-04 20:46:08 +0100137 return false;
138 }
alexander3c798932021-03-26 21:42:19 +0000139
Richard Burtonb40ecf82022-04-22 16:14:57 +0100140 if (!RunInference(model, profiler)) {
141 printf_err("Inference failed.");
Richard Burtone6398cd2022-04-13 11:58:28 +0100142 return false;
143 }
alexander3c798932021-03-26 21:42:19 +0000144
Richard Burtonb40ecf82022-04-22 16:14:57 +0100145 if (!postProcess.DoPostProcess()) {
146 printf_err("Post-processing failed.");
Richard Burtone6398cd2022-04-13 11:58:28 +0100147 return false;
148 }
149
150 /* Add results from this window to our final results vector. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000151 finalResults.emplace_back(kws::KwsResult(
152 singleInfResult,
153 audioDataSlider.Index() * secondsPerSample * preProcess.m_audioDataStride,
154 audioDataSlider.Index(),
155 scoreThreshold));
alexander3c798932021-03-26 21:42:19 +0000156
157#if VERIFY_TEST_OUTPUT
Richard Burtonb40ecf82022-04-22 16:14:57 +0100158 DumpTensor(outputTensor);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000159#endif /* VERIFY_TEST_OUTPUT */
alexander3c798932021-03-26 21:42:19 +0000160 } /* while (audioDataSlider.HasNext()) */
161
162 /* Erase. */
163 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000164 hal_lcd_display_text(
165 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
alexander3c798932021-03-26 21:42:19 +0000166
Richard Burtone6398cd2022-04-13 11:58:28 +0100167 ctx.Set<std::vector<kws::KwsResult>>("results", finalResults);
alexander3c798932021-03-26 21:42:19 +0000168
Richard Burtone6398cd2022-04-13 11:58:28 +0100169 if (!PresentInferenceResult(finalResults)) {
alexander3c798932021-03-26 21:42:19 +0000170 return false;
171 }
172
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100173 profiler.PrintProfilingResult();
174
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000175 IncrementAppCtxIfmIdx(ctx, "clipIndex");
alexander3c798932021-03-26 21:42:19 +0000176
Richard Burtone6398cd2022-04-13 11:58:28 +0100177 } while (runAll && ctx.Get<uint32_t>("clipIndex") != initialClipIdx);
alexander3c798932021-03-26 21:42:19 +0000178
179 return true;
180 }
181
Richard Burtonb40ecf82022-04-22 16:14:57 +0100182 static bool PresentInferenceResult(const std::vector<kws::KwsResult>& results)
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100183 {
184 constexpr uint32_t dataPsnTxtStartX1 = 20;
185 constexpr uint32_t dataPsnTxtStartY1 = 30;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000186 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100187
188 hal_lcd_set_text_color(COLOR_GREEN);
189 info("Final results:\n");
190 info("Total number of inferences: %zu\n", results.size());
191
192 /* Display each result */
193 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
194
Richard Burtonb40ecf82022-04-22 16:14:57 +0100195 for (const auto& result : results) {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100196
197 std::string topKeyword{"<none>"};
198 float score = 0.f;
Richard Burtone6398cd2022-04-13 11:58:28 +0100199 if (!result.m_resultVec.empty()) {
200 topKeyword = result.m_resultVec[0].m_label;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000201 score = result.m_resultVec[0].m_normalisedVal;
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100202 }
203
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000204 std::string resultStr = std::string{"@"} + std::to_string(result.m_timeStamp) +
205 std::string{"s: "} + topKeyword + std::string{" ("} +
206 std::to_string(static_cast<int>(score * 100)) +
207 std::string{"%)"};
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100208
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000209 hal_lcd_display_text(
210 resultStr.c_str(), resultStr.size(), dataPsnTxtStartX1, rowIdx1, false);
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100211 rowIdx1 += dataPsnTxtYIncr;
212
Richard Burtone6398cd2022-04-13 11:58:28 +0100213 if (result.m_resultVec.empty()) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000214 info("For timestamp: %f (inference #: %" PRIu32 "); label: %s; threshold: %f\n",
215 result.m_timeStamp,
216 result.m_inferenceNumber,
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100217 topKeyword.c_str(),
Richard Burtone6398cd2022-04-13 11:58:28 +0100218 result.m_threshold);
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100219 } else {
Richard Burtone6398cd2022-04-13 11:58:28 +0100220 for (uint32_t j = 0; j < result.m_resultVec.size(); ++j) {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100221 info("For timestamp: %f (inference #: %" PRIu32
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000222 "); label: %s, score: %f; threshold: %f\n",
Richard Burtone6398cd2022-04-13 11:58:28 +0100223 result.m_timeStamp,
224 result.m_inferenceNumber,
225 result.m_resultVec[j].m_label.c_str(),
226 result.m_resultVec[j].m_normalisedVal,
227 result.m_threshold);
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100228 }
229 }
230 }
231
232 return true;
233 }
234
alexander3c798932021-03-26 21:42:19 +0000235} /* namespace app */
Richard Burtone6398cd2022-04-13 11:58:28 +0100236} /* namespace arm */