blob: 0c5ff39e562d338a31d6008bcccf7a3e8cce2854 [file] [log] [blame]
Richard Burton00553462021-11-10 16:27:14 +00001/*
Kshitij Sisodia2ea46232022-12-19 16:37:33 +00002 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
3 * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
Richard Burton00553462021-11-10 16:27:14 +00004 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
alexander31ae9f02022-02-10 16:15:54 +000017#include "UseCaseHandler.hpp"
Richard Burton00553462021-11-10 16:27:14 +000018#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000019#include "ImageUtils.hpp"
Richard Burton00553462021-11-10 16:27:14 +000020#include "InputFiles.hpp"
Richard Burton4e002792022-05-04 09:45:02 +010021#include "RNNoiseFeatureProcessor.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000022#include "RNNoiseModel.hpp"
Richard Burton4e002792022-05-04 09:45:02 +010023#include "RNNoiseProcessing.hpp"
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000024#include "UseCaseCommonUtils.hpp"
25#include "hal.h"
alexander31ae9f02022-02-10 16:15:54 +000026#include "log_macros.h"
27
Richard Burton00553462021-11-10 16:27:14 +000028namespace arm {
29namespace app {
30
/**
 * @brief           Helper function to increment current audio clip features index.
 * @param[in,out]   ctx   Reference to the application context object.
 **/
Richard Burton00553462021-11-10 16:27:14 +000035 static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
36
Richard Burton00553462021-11-10 16:27:14 +000037 /* Noise reduction inference handler. */
38 bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
39 {
40 constexpr uint32_t dataPsnTxtInfStartX = 20;
41 constexpr uint32_t dataPsnTxtInfStartY = 40;
42
43 /* Variables used for memory dumping. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000044 size_t memDumpMaxLen = 0;
45 uint8_t* memDumpBaseAddr = nullptr;
Richard Burton00553462021-11-10 16:27:14 +000046 size_t undefMemDumpBytesWritten = 0;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000047 size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
48 if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") &&
49 ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
50 memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
51 memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
Richard Burton00553462021-11-10 16:27:14 +000052 pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN");
53 }
54 std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000055 auto& profiler = ctx.Get<Profiler&>("profiler");
Richard Burton00553462021-11-10 16:27:14 +000056
57 /* Get model reference. */
58 auto& model = ctx.Get<RNNoiseModel&>("model");
59 if (!model.IsInited()) {
60 printf_err("Model is not initialised! Terminating processing.\n");
61 return false;
62 }
63
64 /* Populate Pre-Processing related parameters. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000065 auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
66 auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
Richard Burton00553462021-11-10 16:27:14 +000067 auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
68
69 TfLiteTensor* inputTensor = model.GetInputTensor(0);
70 if (nrNumInputFeatures != inputTensor->bytes) {
71 printf_err("Input features size must be equal to input tensor size."
72 " Feature size = %" PRIu32 ", Tensor size = %zu.\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000073 nrNumInputFeatures,
74 inputTensor->bytes);
Richard Burton00553462021-11-10 16:27:14 +000075 return false;
76 }
77
78 TfLiteTensor* outputTensor = model.GetOutputTensor(model.m_indexForModelOutput);
79
80 /* Initial choice of index for WAV file. */
81 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
82
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000083 std::function<const int16_t*(const uint32_t)> audioAccessorFunc = GetAudioArray;
Richard Burton00553462021-11-10 16:27:14 +000084 if (ctx.Has("features")) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000085 audioAccessorFunc = ctx.Get<std::function<const int16_t*(const uint32_t)>>("features");
Richard Burton00553462021-11-10 16:27:14 +000086 }
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000087 std::function<uint32_t(const uint32_t)> audioSizeAccessorFunc = GetAudioArraySize;
Richard Burton00553462021-11-10 16:27:14 +000088 if (ctx.Has("featureSizes")) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000089 audioSizeAccessorFunc =
90 ctx.Get<std::function<uint32_t(const uint32_t)>>("featureSizes");
Richard Burton00553462021-11-10 16:27:14 +000091 }
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000092 std::function<const char*(const uint32_t)> audioFileAccessorFunc = GetFilename;
Richard Burton00553462021-11-10 16:27:14 +000093 if (ctx.Has("featureFileNames")) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000094 audioFileAccessorFunc =
95 ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
Richard Burton00553462021-11-10 16:27:14 +000096 }
Richard Burton4e002792022-05-04 09:45:02 +010097 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010098 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +000099
Richard Burton00553462021-11-10 16:27:14 +0000100 auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000101 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
Richard Burton00553462021-11-10 16:27:14 +0000102
103 /* Creating a sliding window through the audio. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000104 auto audioDataSlider =
105 audio::SlidingWindow<const int16_t>(audioAccessorFunc(currentIndex),
106 audioSizeAccessorFunc(currentIndex),
107 audioFrameLen,
108 audioFrameStride);
Richard Burton00553462021-11-10 16:27:14 +0000109
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000110 info("Running inference on input feature map %" PRIu32 " => %s\n",
111 currentIndex,
Richard Burton00553462021-11-10 16:27:14 +0000112 audioFileAccessorFunc(currentIndex));
113
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000114 memDumpBytesWritten +=
115 DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
116 (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
117 memDumpBaseAddr + memDumpBytesWritten,
118 memDumpMaxLen - memDumpBytesWritten);
Richard Burton00553462021-11-10 16:27:14 +0000119
Richard Burton4e002792022-05-04 09:45:02 +0100120 /* Set up pre and post-processing. */
121 std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000122 std::make_shared<rnn::RNNoiseFeatureProcessor>();
Richard Burton4e002792022-05-04 09:45:02 +0100123 std::shared_ptr<rnn::FrameFeatures> frameFeatures =
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000124 std::make_shared<rnn::FrameFeatures>();
Richard Burton00553462021-11-10 16:27:14 +0000125
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000126 RNNoisePreProcess preProcess =
127 RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
Richard Burton4e002792022-05-04 09:45:02 +0100128
129 std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000130 RNNoisePostProcess postProcess = RNNoisePostProcess(
131 outputTensor, denoisedAudioFrame, featureProcessor, frameFeatures);
Richard Burton4e002792022-05-04 09:45:02 +0100132
Richard Burton00553462021-11-10 16:27:14 +0000133 bool resetGRU = true;
134
135 while (audioDataSlider.HasNext()) {
136 const int16_t* inferenceWindow = audioDataSlider.Next();
Richard Burton00553462021-11-10 16:27:14 +0000137
Richard Burton4e002792022-05-04 09:45:02 +0100138 if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
139 printf_err("Pre-processing failed.");
140 return false;
141 }
Richard Burton00553462021-11-10 16:27:14 +0000142
143 /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000144 if (resetGRU) {
Richard Burton00553462021-11-10 16:27:14 +0000145 model.ResetGruState();
146 } else {
147 /* Copying gru state outputs to gru state inputs.
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000148 * Call ResetGruState in between the sequence of inferences on unrelated input
149 * data. */
Richard Burton00553462021-11-10 16:27:14 +0000150 model.CopyGruStates();
151 }
152
Richard Burton00553462021-11-10 16:27:14 +0000153 /* Strings for presentation/logging. */
154 std::string str_inf{"Running inference... "};
155
156 /* Display message on the LCD - inference running. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000157 hal_lcd_display_text(str_inf.c_str(),
158 str_inf.size(),
159 dataPsnTxtInfStartX,
160 dataPsnTxtInfStartY,
161 false);
Richard Burton00553462021-11-10 16:27:14 +0000162
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000163 info("Inference %zu/%zu\n",
164 audioDataSlider.Index() + 1,
165 audioDataSlider.TotalStrides() + 1);
Richard Burton00553462021-11-10 16:27:14 +0000166
167 /* Run inference over this feature sliding window. */
Richard Burton4e002792022-05-04 09:45:02 +0100168 if (!RunInference(model, profiler)) {
169 printf_err("Inference failed.");
Richard Burton00553462021-11-10 16:27:14 +0000170 return false;
171 }
Richard Burton4e002792022-05-04 09:45:02 +0100172 resetGRU = false;
Richard Burton00553462021-11-10 16:27:14 +0000173
Richard Burton4e002792022-05-04 09:45:02 +0100174 /* Carry out post-processing. */
175 if (!postProcess.DoPostProcess()) {
176 printf_err("Post-processing failed.");
177 return false;
Richard Burton00553462021-11-10 16:27:14 +0000178 }
179
180 /* Erase. */
181 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000182 hal_lcd_display_text(str_inf.c_str(),
183 str_inf.size(),
184 dataPsnTxtInfStartX,
185 dataPsnTxtInfStartY,
186 false);
Richard Burton00553462021-11-10 16:27:14 +0000187
188 if (memDumpMaxLen > 0) {
Richard Burton4e002792022-05-04 09:45:02 +0100189 /* Dump final post processed output to memory. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000190 memDumpBytesWritten +=
191 DumpOutputDenoisedAudioFrame(denoisedAudioFrame,
192 memDumpBaseAddr + memDumpBytesWritten,
193 memDumpMaxLen - memDumpBytesWritten);
Richard Burton00553462021-11-10 16:27:14 +0000194 }
195 }
196
197 if (memDumpMaxLen > 0) {
198 /* Needed to not let the compiler complain about type mismatch. */
199 size_t valMemDumpBytesWritten = memDumpBytesWritten;
200 info("Output memory dump of %zu bytes written at address 0x%p\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000201 valMemDumpBytesWritten,
202 startDumpAddress);
Richard Burton00553462021-11-10 16:27:14 +0000203 }
204
Richard Burton4e002792022-05-04 09:45:02 +0100205 /* Finish by dumping the footer. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000206 DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten,
207 memDumpMaxLen - memDumpBytesWritten);
Richard Burton00553462021-11-10 16:27:14 +0000208
Richard Burton9b8d67a2021-12-10 12:32:51 +0000209 info("All inferences for audio clip complete.\n");
Richard Burton00553462021-11-10 16:27:14 +0000210 profiler.PrintProfilingResult();
211 IncrementAppCtxClipIdx(ctx);
212
Ayaan Masood233cec02021-12-09 17:22:22 +0000213 std::string clearString{' '};
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000214 hal_lcd_display_text(clearString.c_str(),
215 clearString.size(),
216 dataPsnTxtInfStartX,
217 dataPsnTxtInfStartY,
218 false);
Ayaan Masood233cec02021-12-09 17:22:22 +0000219
220 std::string completeMsg{"Inference complete!"};
221
222 /* Display message on the LCD - inference complete. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000223 hal_lcd_display_text(completeMsg.c_str(),
224 completeMsg.size(),
225 dataPsnTxtInfStartX,
226 dataPsnTxtInfStartY,
227 false);
Ayaan Masood233cec02021-12-09 17:22:22 +0000228
Richard Burton00553462021-11-10 16:27:14 +0000229 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
230
231 return true;
232 }
233
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000234 size_t DumpDenoisedAudioHeader(const char* filename,
235 size_t dumpSize,
236 uint8_t* memAddress,
237 size_t memSize)
238 {
Richard Burton00553462021-11-10 16:27:14 +0000239
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000240 if (memAddress == nullptr) {
Richard Burton00553462021-11-10 16:27:14 +0000241 return 0;
242 }
243
244 int32_t filenameLength = strlen(filename);
245 size_t numBytesWritten = 0;
246 size_t numBytesToWrite = 0;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000247 int32_t dumpSizeByte = dumpSize * sizeof(int16_t);
248 bool overflow = false;
Richard Burton00553462021-11-10 16:27:14 +0000249
250 /* Write the filename length */
251 numBytesToWrite = sizeof(filenameLength);
252 if (memSize - numBytesToWrite > 0) {
253 std::memcpy(memAddress, &filenameLength, numBytesToWrite);
254 numBytesWritten += numBytesToWrite;
255 memSize -= numBytesWritten;
256 } else {
257 overflow = true;
258 }
259
260 /* Write file name */
261 numBytesToWrite = filenameLength;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000262 if (memSize - numBytesToWrite > 0) {
Richard Burton00553462021-11-10 16:27:14 +0000263 std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite);
264 numBytesWritten += numBytesToWrite;
265 memSize -= numBytesWritten;
266 } else {
267 overflow = true;
268 }
269
270 /* Write dumpSize in byte */
271 numBytesToWrite = sizeof(dumpSizeByte);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000272 if (memSize - numBytesToWrite > 0) {
Richard Burton00553462021-11-10 16:27:14 +0000273 std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite);
274 numBytesWritten += numBytesToWrite;
275 memSize -= numBytesWritten;
276 } else {
277 overflow = true;
278 }
279
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000280 if (false == overflow) {
281 info("Audio Clip dump header info (%zu bytes) written to %p\n",
282 numBytesWritten,
283 memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000284 } else {
285 printf_err("Not enough memory to dump Audio Clip header.\n");
286 }
287
288 return numBytesWritten;
289 }
290
/**
 * @brief   Writes the end-of-dump marker (an int32_t with value -1).
 *
 * @param[out]  memAddress  Destination for the marker.
 * @param[in]   memSize     Number of bytes available at memAddress.
 * @return  Number of bytes written: sizeof(int32_t), or 0 when memAddress is
 *          nullptr or fewer than sizeof(int32_t) bytes are available.
 **/
size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize)
{
    constexpr int32_t endOfDumpMarker = -1;
    constexpr size_t footerBytes      = sizeof(endOfDumpMarker);

    const bool canWrite = (memAddress != nullptr) && (memSize >= footerBytes);
    if (canWrite) {
        std::memcpy(memAddress, &endOfDumpMarker, footerBytes);
        return footerBytes;
    }

    return 0;
}
Richard Burton00553462021-11-10 16:27:14 +0000301
Richard Burton4e002792022-05-04 09:45:02 +0100302 size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000303 uint8_t* memAddress,
304 size_t memSize)
Richard Burton00553462021-11-10 16:27:14 +0000305 {
306 if (memAddress == nullptr) {
307 return 0;
308 }
309
310 size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000311 if (numByteToBeWritten > memSize) {
312 printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n",
313 memSize,
314 numByteToBeWritten,
315 memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000316 numByteToBeWritten = memSize;
317 }
318
319 std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000320 info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000321
322 return numByteToBeWritten;
323 }
324
325 size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize)
326 {
327 const size_t numOutputs = model.GetNumOutputs();
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000328 size_t numBytesWritten = 0;
329 uint8_t* ptr = memAddress;
Richard Burton00553462021-11-10 16:27:14 +0000330
331 /* Iterate over all output tensors. */
332 for (size_t i = 0; i < numOutputs; ++i) {
333 const TfLiteTensor* tensor = model.GetOutputTensor(i);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000334 const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
Richard Burton00553462021-11-10 16:27:14 +0000335#if VERIFY_TEST_OUTPUT
Richard Burton4e002792022-05-04 09:45:02 +0100336 DumpTensor(tensor);
Richard Burton00553462021-11-10 16:27:14 +0000337#endif /* VERIFY_TEST_OUTPUT */
338 /* Ensure that we don't overflow the allowed limit. */
339 if (numBytesWritten + tensor->bytes <= memSize) {
340 if (tensor->bytes > 0) {
341 std::memcpy(ptr, tData, tensor->bytes);
342
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000343 info("Copied %zu bytes for tensor %zu to 0x%p\n", tensor->bytes, i, ptr);
Richard Burton00553462021-11-10 16:27:14 +0000344
345 numBytesWritten += tensor->bytes;
346 ptr += tensor->bytes;
347 }
348 } else {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000349 printf_err("Error writing tensor %zu to memory @ 0x%p\n", i, memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000350 break;
351 }
352 }
353
354 info("%zu bytes written to memory @ 0x%p\n", numBytesWritten, memAddress);
355
356 return numBytesWritten;
357 }
358
359 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
360 {
361 auto curClipIdx = ctx.Get<uint32_t>("clipIndex");
362 if (curClipIdx + 1 >= NUMBER_OF_FILES) {
363 ctx.Set<uint32_t>("clipIndex", 0);
364 return;
365 }
366 ++curClipIdx;
367 ctx.Set<uint32_t>("clipIndex", curClipIdx);
368 }
369
Richard Burton00553462021-11-10 16:27:14 +0000370} /* namespace app */
371} /* namespace arm */