blob: 53bb43ea0e702126c59169bf33583acd2c0d0c31 [file] [log] [blame]
Richard Burton00553462021-11-10 16:27:14 +00001/*
Richard Burtoned35a6f2022-02-14 11:55:35 +00002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
Richard Burton00553462021-11-10 16:27:14 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
Richard Burton00553462021-11-10 16:27:14 +000017#include "hal.h"
alexander31ae9f02022-02-10 16:15:54 +000018#include "UseCaseHandler.hpp"
Richard Burton00553462021-11-10 16:27:14 +000019#include "UseCaseCommonUtils.hpp"
20#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000021#include "ImageUtils.hpp"
Richard Burton00553462021-11-10 16:27:14 +000022#include "InputFiles.hpp"
23#include "RNNoiseModel.hpp"
Richard Burton4e002792022-05-04 09:45:02 +010024#include "RNNoiseFeatureProcessor.hpp"
25#include "RNNoiseProcessing.hpp"
alexander31ae9f02022-02-10 16:15:54 +000026#include "log_macros.h"
27
Richard Burton00553462021-11-10 16:27:14 +000028namespace arm {
29namespace app {
30
/**
 * @brief           Helper function to increment the current audio clip index
 *                  stored in the application context.
 * @param[in,out]   ctx   Reference to the application context object.
 **/
35 static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
36
Richard Burton00553462021-11-10 16:27:14 +000037 /* Noise reduction inference handler. */
38 bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
39 {
40 constexpr uint32_t dataPsnTxtInfStartX = 20;
41 constexpr uint32_t dataPsnTxtInfStartY = 40;
42
43 /* Variables used for memory dumping. */
44 size_t memDumpMaxLen = 0;
45 uint8_t* memDumpBaseAddr = nullptr;
46 size_t undefMemDumpBytesWritten = 0;
Richard Burton4e002792022-05-04 09:45:02 +010047 size_t* pMemDumpBytesWritten = &undefMemDumpBytesWritten;
Richard Burton00553462021-11-10 16:27:14 +000048 if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
49 memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
50 memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
51 pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN");
52 }
53 std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten);
Richard Burton00553462021-11-10 16:27:14 +000054 auto& profiler = ctx.Get<Profiler&>("profiler");
55
56 /* Get model reference. */
57 auto& model = ctx.Get<RNNoiseModel&>("model");
58 if (!model.IsInited()) {
59 printf_err("Model is not initialised! Terminating processing.\n");
60 return false;
61 }
62
63 /* Populate Pre-Processing related parameters. */
Richard Burton4e002792022-05-04 09:45:02 +010064 auto audioFrameLen = ctx.Get<uint32_t>("frameLength");
65 auto audioFrameStride = ctx.Get<uint32_t>("frameStride");
Richard Burton00553462021-11-10 16:27:14 +000066 auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");
67
68 TfLiteTensor* inputTensor = model.GetInputTensor(0);
69 if (nrNumInputFeatures != inputTensor->bytes) {
70 printf_err("Input features size must be equal to input tensor size."
71 " Feature size = %" PRIu32 ", Tensor size = %zu.\n",
72 nrNumInputFeatures, inputTensor->bytes);
73 return false;
74 }
75
76 TfLiteTensor* outputTensor = model.GetOutputTensor(model.m_indexForModelOutput);
77
78 /* Initial choice of index for WAV file. */
79 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
80
81 std::function<const int16_t* (const uint32_t)> audioAccessorFunc = get_audio_array;
82 if (ctx.Has("features")) {
83 audioAccessorFunc = ctx.Get<std::function<const int16_t* (const uint32_t)>>("features");
84 }
85 std::function<uint32_t (const uint32_t)> audioSizeAccessorFunc = get_audio_array_size;
86 if (ctx.Has("featureSizes")) {
87 audioSizeAccessorFunc = ctx.Get<std::function<uint32_t (const uint32_t)>>("featureSizes");
88 }
89 std::function<const char*(const uint32_t)> audioFileAccessorFunc = get_filename;
90 if (ctx.Has("featureFileNames")) {
91 audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
92 }
Richard Burton4e002792022-05-04 09:45:02 +010093 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010094 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +000095
Richard Burton00553462021-11-10 16:27:14 +000096 auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
97 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
98
99 /* Creating a sliding window through the audio. */
100 auto audioDataSlider = audio::SlidingWindow<const int16_t>(
101 audioAccessorFunc(currentIndex),
Richard Burton4e002792022-05-04 09:45:02 +0100102 audioSizeAccessorFunc(currentIndex), audioFrameLen,
103 audioFrameStride);
Richard Burton00553462021-11-10 16:27:14 +0000104
105 info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
106 audioFileAccessorFunc(currentIndex));
107
108 memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
Richard Burton4e002792022-05-04 09:45:02 +0100109 (audioDataSlider.TotalStrides() + 1) * audioFrameLen,
Richard Burton00553462021-11-10 16:27:14 +0000110 memDumpBaseAddr + memDumpBytesWritten,
111 memDumpMaxLen - memDumpBytesWritten);
112
Richard Burton4e002792022-05-04 09:45:02 +0100113 /* Set up pre and post-processing. */
114 std::shared_ptr<rnn::RNNoiseFeatureProcessor> featureProcessor =
115 std::make_shared<rnn::RNNoiseFeatureProcessor>();
116 std::shared_ptr<rnn::FrameFeatures> frameFeatures =
117 std::make_shared<rnn::FrameFeatures>();
Richard Burton00553462021-11-10 16:27:14 +0000118
Richard Burton4e002792022-05-04 09:45:02 +0100119 RNNoisePreProcess preProcess = RNNoisePreProcess(inputTensor, featureProcessor, frameFeatures);
120
121 std::vector<int16_t> denoisedAudioFrame(audioFrameLen);
122 RNNoisePostProcess postProcess = RNNoisePostProcess(outputTensor, denoisedAudioFrame,
123 featureProcessor, frameFeatures);
124
Richard Burton00553462021-11-10 16:27:14 +0000125 bool resetGRU = true;
126
127 while (audioDataSlider.HasNext()) {
128 const int16_t* inferenceWindow = audioDataSlider.Next();
Richard Burton00553462021-11-10 16:27:14 +0000129
Richard Burton4e002792022-05-04 09:45:02 +0100130 if (!preProcess.DoPreProcess(inferenceWindow, audioFrameLen)) {
131 printf_err("Pre-processing failed.");
132 return false;
133 }
Richard Burton00553462021-11-10 16:27:14 +0000134
135 /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
136 if (resetGRU){
137 model.ResetGruState();
138 } else {
139 /* Copying gru state outputs to gru state inputs.
140 * Call ResetGruState in between the sequence of inferences on unrelated input data. */
141 model.CopyGruStates();
142 }
143
Richard Burton00553462021-11-10 16:27:14 +0000144 /* Strings for presentation/logging. */
145 std::string str_inf{"Running inference... "};
146
147 /* Display message on the LCD - inference running. */
Richard Burton4e002792022-05-04 09:45:02 +0100148 hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
149 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
Richard Burton00553462021-11-10 16:27:14 +0000150
151 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);
152
153 /* Run inference over this feature sliding window. */
Richard Burton4e002792022-05-04 09:45:02 +0100154 if (!RunInference(model, profiler)) {
155 printf_err("Inference failed.");
Richard Burton00553462021-11-10 16:27:14 +0000156 return false;
157 }
Richard Burton4e002792022-05-04 09:45:02 +0100158 resetGRU = false;
Richard Burton00553462021-11-10 16:27:14 +0000159
Richard Burton4e002792022-05-04 09:45:02 +0100160 /* Carry out post-processing. */
161 if (!postProcess.DoPostProcess()) {
162 printf_err("Post-processing failed.");
163 return false;
Richard Burton00553462021-11-10 16:27:14 +0000164 }
165
166 /* Erase. */
167 str_inf = std::string(str_inf.size(), ' ');
Richard Burton4e002792022-05-04 09:45:02 +0100168 hal_lcd_display_text(str_inf.c_str(), str_inf.size(),
169 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
Richard Burton00553462021-11-10 16:27:14 +0000170
171 if (memDumpMaxLen > 0) {
Richard Burton4e002792022-05-04 09:45:02 +0100172 /* Dump final post processed output to memory. */
Richard Burton00553462021-11-10 16:27:14 +0000173 memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
174 denoisedAudioFrame,
175 memDumpBaseAddr + memDumpBytesWritten,
176 memDumpMaxLen - memDumpBytesWritten);
177 }
178 }
179
180 if (memDumpMaxLen > 0) {
181 /* Needed to not let the compiler complain about type mismatch. */
182 size_t valMemDumpBytesWritten = memDumpBytesWritten;
183 info("Output memory dump of %zu bytes written at address 0x%p\n",
184 valMemDumpBytesWritten, startDumpAddress);
185 }
186
Richard Burton4e002792022-05-04 09:45:02 +0100187 /* Finish by dumping the footer. */
Richard Burton00553462021-11-10 16:27:14 +0000188 DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);
189
Richard Burton9b8d67a2021-12-10 12:32:51 +0000190 info("All inferences for audio clip complete.\n");
Richard Burton00553462021-11-10 16:27:14 +0000191 profiler.PrintProfilingResult();
192 IncrementAppCtxClipIdx(ctx);
193
Ayaan Masood233cec02021-12-09 17:22:22 +0000194 std::string clearString{' '};
Richard Burton4e002792022-05-04 09:45:02 +0100195 hal_lcd_display_text(clearString.c_str(), clearString.size(),
Ayaan Masood233cec02021-12-09 17:22:22 +0000196 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
197
198 std::string completeMsg{"Inference complete!"};
199
200 /* Display message on the LCD - inference complete. */
Richard Burton4e002792022-05-04 09:45:02 +0100201 hal_lcd_display_text(completeMsg.c_str(), completeMsg.size(),
Ayaan Masood233cec02021-12-09 17:22:22 +0000202 dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);
203
Richard Burton00553462021-11-10 16:27:14 +0000204 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
205
206 return true;
207 }
208
209 size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
Richard Burton4e002792022-05-04 09:45:02 +0100210 uint8_t* memAddress, size_t memSize){
Richard Burton00553462021-11-10 16:27:14 +0000211
212 if (memAddress == nullptr){
213 return 0;
214 }
215
216 int32_t filenameLength = strlen(filename);
217 size_t numBytesWritten = 0;
218 size_t numBytesToWrite = 0;
219 int32_t dumpSizeByte = dumpSize * sizeof(int16_t);
220 bool overflow = false;
221
222 /* Write the filename length */
223 numBytesToWrite = sizeof(filenameLength);
224 if (memSize - numBytesToWrite > 0) {
225 std::memcpy(memAddress, &filenameLength, numBytesToWrite);
226 numBytesWritten += numBytesToWrite;
227 memSize -= numBytesWritten;
228 } else {
229 overflow = true;
230 }
231
232 /* Write file name */
233 numBytesToWrite = filenameLength;
234 if(memSize - numBytesToWrite > 0) {
235 std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite);
236 numBytesWritten += numBytesToWrite;
237 memSize -= numBytesWritten;
238 } else {
239 overflow = true;
240 }
241
242 /* Write dumpSize in byte */
243 numBytesToWrite = sizeof(dumpSizeByte);
244 if(memSize - numBytesToWrite > 0) {
245 std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite);
246 numBytesWritten += numBytesToWrite;
247 memSize -= numBytesWritten;
248 } else {
249 overflow = true;
250 }
251
252 if(false == overflow) {
253 info("Audio Clip dump header info (%zu bytes) written to %p\n", numBytesWritten, memAddress);
254 } else {
255 printf_err("Not enough memory to dump Audio Clip header.\n");
256 }
257
258 return numBytesWritten;
259 }
260
/**
 * @brief           Writes the end-of-file marker (int32_t value -1) to memory.
 * @param[out]      memAddress   Destination memory; nothing is written if nullptr.
 * @param[in]       memSize      Number of bytes available at memAddress.
 * @return          Number of bytes written: sizeof(int32_t) on success, 0 if
 *                  memAddress is nullptr or the marker does not fit.
 **/
size_t DumpDenoisedAudioFooter(uint8_t* memAddress, size_t memSize)
{
    constexpr int32_t eofMarker = -1;

    const bool canWrite = (memAddress != nullptr) && (memSize >= sizeof(eofMarker));
    if (!canWrite) {
        return 0;
    }

    std::memcpy(memAddress, &eofMarker, sizeof(eofMarker));
    return sizeof(eofMarker);
}
270
Richard Burton4e002792022-05-04 09:45:02 +0100271 size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t>& audioFrame,
272 uint8_t* memAddress, size_t memSize)
Richard Burton00553462021-11-10 16:27:14 +0000273 {
274 if (memAddress == nullptr) {
275 return 0;
276 }
277
278 size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t);
279 if( numByteToBeWritten > memSize) {
George Gekova2b0fc22021-11-08 16:30:43 +0000280 printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", memSize, numByteToBeWritten, memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000281 numByteToBeWritten = memSize;
282 }
283
284 std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten);
285 info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress);
286
287 return numByteToBeWritten;
288 }
289
290 size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize)
291 {
292 const size_t numOutputs = model.GetNumOutputs();
293 size_t numBytesWritten = 0;
294 uint8_t* ptr = memAddress;
295
296 /* Iterate over all output tensors. */
297 for (size_t i = 0; i < numOutputs; ++i) {
298 const TfLiteTensor* tensor = model.GetOutputTensor(i);
299 const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
300#if VERIFY_TEST_OUTPUT
Richard Burton4e002792022-05-04 09:45:02 +0100301 DumpTensor(tensor);
Richard Burton00553462021-11-10 16:27:14 +0000302#endif /* VERIFY_TEST_OUTPUT */
303 /* Ensure that we don't overflow the allowed limit. */
304 if (numBytesWritten + tensor->bytes <= memSize) {
305 if (tensor->bytes > 0) {
306 std::memcpy(ptr, tData, tensor->bytes);
307
308 info("Copied %zu bytes for tensor %zu to 0x%p\n",
309 tensor->bytes, i, ptr);
310
311 numBytesWritten += tensor->bytes;
312 ptr += tensor->bytes;
313 }
314 } else {
315 printf_err("Error writing tensor %zu to memory @ 0x%p\n",
316 i, memAddress);
317 break;
318 }
319 }
320
321 info("%zu bytes written to memory @ 0x%p\n", numBytesWritten, memAddress);
322
323 return numBytesWritten;
324 }
325
326 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
327 {
328 auto curClipIdx = ctx.Get<uint32_t>("clipIndex");
329 if (curClipIdx + 1 >= NUMBER_OF_FILES) {
330 ctx.Set<uint32_t>("clipIndex", 0);
331 return;
332 }
333 ++curClipIdx;
334 ctx.Set<uint32_t>("clipIndex", curClipIdx);
335 }
336
Richard Burton00553462021-11-10 16:27:14 +0000337} /* namespace app */
338} /* namespace arm */