blob: 0c5984c85a3dbc860d35e772a54cb63c8614afeb [file] [log] [blame]
Richard Burton00553462021-11-10 16:27:14 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
Richard Burton00553462021-11-10 16:27:14 +000017#include "hal.h"
alexander31ae9f02022-02-10 16:15:54 +000018#include "UseCaseHandler.hpp"
Richard Burton00553462021-11-10 16:27:14 +000019#include "UseCaseCommonUtils.hpp"
20#include "AudioUtils.hpp"
21#include "InputFiles.hpp"
22#include "RNNoiseModel.hpp"
23#include "RNNoiseProcess.hpp"
alexander31ae9f02022-02-10 16:15:54 +000024#include "log_macros.h"
25
26#include <cmath>
27#include <algorithm>
Richard Burton00553462021-11-10 16:27:14 +000028
29namespace arm {
30namespace app {
31
    /**
     * @brief Helper function to increment current audio clip features index.
     *        Wraps back to clip 0 once the last available file is reached.
     * @param[in,out] ctx Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Quantize the given features and populate the input Tensor.
     * @param[in] inputFeatures Vector of floating point features to quantize.
     * @param[in] quantScale Quantization scale for the inputTensor.
     * @param[in] quantOffset Quantization offset for the inputTensor.
     * @param[in,out] inputTensor TFLite micro tensor to populate.
     **/
    static void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
                                         float quantScale, int quantOffset,
                                         TfLiteTensor* inputTensor);
48
    /**
     * @brief Noise reduction inference handler.
     *
     * Runs the RNNoise model over one (or all) input audio clips: slides a
     * window over the audio, extracts features per frame, runs inference,
     * post-processes and (optionally) dumps the denoised int16 audio to a
     * caller-supplied memory region framed by a header and an EOF footer.
     *
     * @param[in,out] ctx    Application context; must provide "platform",
     *                       "profiler", "model", "clipIndex", "frameLength",
     *                       "frameStride" and "numInputFeatures". Optional
     *                       keys: MEM_DUMP_* (memory dump region) and
     *                       features/featureSizes/featureFileNames (custom
     *                       audio accessors).
     * @param[in]     runAll If true, loop over every clip until the index
     *                       wraps back to the starting clip.
     * @return        true on success, false on model/inference failure.
     **/
    bool NoiseReductionHandler(ApplicationContext& ctx, bool runAll)
    {
        /* LCD text origin for status messages. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;

        /* Variables used for memory dumping. If the caller did not provide a
         * dump region, writes are counted into a local dummy instead. */
        size_t memDumpMaxLen = 0;
        uint8_t* memDumpBaseAddr = nullptr;
        size_t undefMemDumpBytesWritten = 0;
        size_t *pMemDumpBytesWritten = &undefMemDumpBytesWritten;
        if (ctx.Has("MEM_DUMP_LEN") && ctx.Has("MEM_DUMP_BASE_ADDR") && ctx.Has("MEM_DUMP_BYTE_WRITTEN")) {
            memDumpMaxLen = ctx.Get<size_t>("MEM_DUMP_LEN");
            memDumpBaseAddr = ctx.Get<uint8_t*>("MEM_DUMP_BASE_ADDR");
            pMemDumpBytesWritten = ctx.Get<size_t*>("MEM_DUMP_BYTE_WRITTEN");
        }
        /* Reference wrapper so that increments below update the caller's
         * counter (when provided) rather than a local copy. */
        std::reference_wrapper<size_t> memDumpBytesWritten = std::ref(*pMemDumpBytesWritten);

        auto& platform = ctx.Get<hal_platform&>("platform");
        auto& profiler = ctx.Get<Profiler&>("profiler");

        /* Get model reference. */
        auto& model = ctx.Get<RNNoiseModel&>("model");
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        /* Populate Pre-Processing related parameters. */
        auto audioParamsWinLen = ctx.Get<uint32_t>("frameLength");
        auto audioParamsWinStride = ctx.Get<uint32_t>("frameStride");
        auto nrNumInputFeatures = ctx.Get<uint32_t>("numInputFeatures");

        TfLiteTensor* inputTensor = model.GetInputTensor(0);
        /* NOTE(review): comparing a feature *count* against a *byte* size —
         * this assumes a 1-byte (int8 quantized) input element; confirm. */
        if (nrNumInputFeatures != inputTensor->bytes) {
            printf_err("Input features size must be equal to input tensor size."
                       " Feature size = %" PRIu32 ", Tensor size = %zu.\n",
                       nrNumInputFeatures, inputTensor->bytes);
            return false;
        }

        TfLiteTensor* outputTensor = model.GetOutputTensor(model.m_indexForModelOutput);

        /* Initial choice of index for WAV file. */
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        /* Default accessors read the baked-in audio arrays; the context may
         * override them (e.g. for tests) via the "feature*" keys. */
        std::function<const int16_t* (const uint32_t)> audioAccessorFunc = get_audio_array;
        if (ctx.Has("features")) {
            audioAccessorFunc = ctx.Get<std::function<const int16_t* (const uint32_t)>>("features");
        }
        std::function<uint32_t (const uint32_t)> audioSizeAccessorFunc = get_audio_array_size;
        if (ctx.Has("featureSizes")) {
            audioSizeAccessorFunc = ctx.Get<std::function<uint32_t (const uint32_t)>>("featureSizes");
        }
        std::function<const char*(const uint32_t)> audioFileAccessorFunc = get_filename;
        if (ctx.Has("featureFileNames")) {
            audioFileAccessorFunc = ctx.Get<std::function<const char*(const uint32_t)>>("featureFileNames");
        }
        do{
            platform.data_psn->clear(COLOR_BLACK);

            /* Remember where this clip's dump starts, for the summary log. */
            auto startDumpAddress = memDumpBaseAddr + memDumpBytesWritten;
            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Creating a sliding window through the audio. */
            auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                    audioAccessorFunc(currentIndex),
                    audioSizeAccessorFunc(currentIndex), audioParamsWinLen,
                    audioParamsWinStride);

            info("Running inference on input feature map %" PRIu32 " => %s\n", currentIndex,
                 audioFileAccessorFunc(currentIndex));

            /* Write the per-clip dump header (filename + total sample count).
             * Safe when no dump region is set: the helper no-ops on nullptr. */
            memDumpBytesWritten += DumpDenoisedAudioHeader(audioFileAccessorFunc(currentIndex),
                    (audioDataSlider.TotalStrides() + 1) * audioParamsWinLen,
                    memDumpBaseAddr + memDumpBytesWritten,
                    memDumpMaxLen - memDumpBytesWritten);

            /* Per-clip working buffers. */
            rnn::RNNoiseProcess featureProcessor = rnn::RNNoiseProcess();
            rnn::vec1D32F audioFrame(audioParamsWinLen);
            rnn::vec1D32F inputFeatures(nrNumInputFeatures);
            rnn::vec1D32F denoisedAudioFrameFloat(audioParamsWinLen);
            std::vector<int16_t> denoisedAudioFrame(audioParamsWinLen);

            /* NOTE(review): sized in bytes — assumes int8 output elements. */
            std::vector<float> modelOutputFloat(outputTensor->bytes);
            rnn::FrameFeatures frameFeatures;
            bool resetGRU = true;

            while (audioDataSlider.HasNext()) {
                const int16_t* inferenceWindow = audioDataSlider.Next();
                audioFrame = rnn::vec1D32F(inferenceWindow, inferenceWindow+audioParamsWinLen);

                featureProcessor.PreprocessFrame(audioFrame.data(), audioParamsWinLen, frameFeatures);

                /* Reset or copy over GRU states first to avoid TFLu memory overlap issues. */
                if (resetGRU){
                    /* First frame of a clip: start from a clean recurrent state. */
                    model.ResetGruState();
                } else {
                    /* Copying gru state outputs to gru state inputs.
                     * Call ResetGruState in between the sequence of inferences on unrelated input data. */
                    model.CopyGruStates();
                }

                QuantizeAndPopulateInput(frameFeatures.m_featuresVec,
                        inputTensor->params.scale, inputTensor->params.zero_point,
                        inputTensor);

                /* Strings for presentation/logging. */
                std::string str_inf{"Running inference... "};

                /* Display message on the LCD - inference running. */
                platform.data_psn->present_data_text(
                        str_inf.c_str(), str_inf.size(),
                        dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1, audioDataSlider.TotalStrides() + 1);

                /* Run inference over this feature sliding window. */
                profiler.StartProfiling("Inference");
                bool success = model.RunInference();
                profiler.StopProfiling();
                resetGRU = false;

                if (!success) {
                    return false;
                }

                /* De-quantize main model output ready for post-processing. */
                const auto* outputData = tflite::GetTensorData<int8_t>(outputTensor);
                auto outputQuantParams = arm::app::GetTensorQuantParams(outputTensor);

                for (size_t i = 0; i < outputTensor->bytes; ++i) {
                    modelOutputFloat[i] = (static_cast<float>(outputData[i]) - outputQuantParams.offset)
                            * outputQuantParams.scale;
                }

                /* Round and cast the post-processed results for dumping to wav. */
                featureProcessor.PostProcessFrame(modelOutputFloat, frameFeatures, denoisedAudioFrameFloat);
                for (size_t i = 0; i < audioParamsWinLen; ++i) {
                    denoisedAudioFrame[i] = static_cast<int16_t>(std::roundf(denoisedAudioFrameFloat[i]));
                }

                /* Erase. */
                str_inf = std::string(str_inf.size(), ' ');
                platform.data_psn->present_data_text(
                        str_inf.c_str(), str_inf.size(),
                        dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

                if (memDumpMaxLen > 0) {
                    /* Dump output tensors to memory. */
                    memDumpBytesWritten += DumpOutputDenoisedAudioFrame(
                            denoisedAudioFrame,
                            memDumpBaseAddr + memDumpBytesWritten,
                            memDumpMaxLen - memDumpBytesWritten);
                }
            }

            if (memDumpMaxLen > 0) {
                /* Needed to not let the compiler complain about type mismatch. */
                size_t valMemDumpBytesWritten = memDumpBytesWritten;
                info("Output memory dump of %zu bytes written at address 0x%p\n",
                     valMemDumpBytesWritten, startDumpAddress);
            }

            /* Terminate the clip's dump with an EOF marker (no-op on nullptr). */
            DumpDenoisedAudioFooter(memDumpBaseAddr + memDumpBytesWritten, memDumpMaxLen - memDumpBytesWritten);

            info("All inferences for audio clip complete.\n");
            profiler.PrintProfilingResult();
            IncrementAppCtxClipIdx(ctx);

            std::string clearString{' '};
            platform.data_psn->present_data_text(
                clearString.c_str(), clearString.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

            std::string completeMsg{"Inference complete!"};

            /* Display message on the LCD - inference complete. */
            platform.data_psn->present_data_text(
                completeMsg.c_str(), completeMsg.size(),
                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }
235
236 size_t DumpDenoisedAudioHeader(const char* filename, size_t dumpSize,
237 uint8_t *memAddress, size_t memSize){
238
239 if (memAddress == nullptr){
240 return 0;
241 }
242
243 int32_t filenameLength = strlen(filename);
244 size_t numBytesWritten = 0;
245 size_t numBytesToWrite = 0;
246 int32_t dumpSizeByte = dumpSize * sizeof(int16_t);
247 bool overflow = false;
248
249 /* Write the filename length */
250 numBytesToWrite = sizeof(filenameLength);
251 if (memSize - numBytesToWrite > 0) {
252 std::memcpy(memAddress, &filenameLength, numBytesToWrite);
253 numBytesWritten += numBytesToWrite;
254 memSize -= numBytesWritten;
255 } else {
256 overflow = true;
257 }
258
259 /* Write file name */
260 numBytesToWrite = filenameLength;
261 if(memSize - numBytesToWrite > 0) {
262 std::memcpy(memAddress + numBytesWritten, filename, numBytesToWrite);
263 numBytesWritten += numBytesToWrite;
264 memSize -= numBytesWritten;
265 } else {
266 overflow = true;
267 }
268
269 /* Write dumpSize in byte */
270 numBytesToWrite = sizeof(dumpSizeByte);
271 if(memSize - numBytesToWrite > 0) {
272 std::memcpy(memAddress + numBytesWritten, &(dumpSizeByte), numBytesToWrite);
273 numBytesWritten += numBytesToWrite;
274 memSize -= numBytesWritten;
275 } else {
276 overflow = true;
277 }
278
279 if(false == overflow) {
280 info("Audio Clip dump header info (%zu bytes) written to %p\n", numBytesWritten, memAddress);
281 } else {
282 printf_err("Not enough memory to dump Audio Clip header.\n");
283 }
284
285 return numBytesWritten;
286 }
287
288 size_t DumpDenoisedAudioFooter(uint8_t *memAddress, size_t memSize){
289 if ((memAddress == nullptr) || (memSize < 4)) {
290 return 0;
291 }
292 const int32_t eofMarker = -1;
293 std::memcpy(memAddress, &eofMarker, sizeof(int32_t));
294
295 return sizeof(int32_t);
296 }
297
298 size_t DumpOutputDenoisedAudioFrame(const std::vector<int16_t> &audioFrame,
299 uint8_t *memAddress, size_t memSize)
300 {
301 if (memAddress == nullptr) {
302 return 0;
303 }
304
305 size_t numByteToBeWritten = audioFrame.size() * sizeof(int16_t);
306 if( numByteToBeWritten > memSize) {
George Gekova2b0fc22021-11-08 16:30:43 +0000307 printf_err("Overflow error: Writing %zu of %zu bytes to memory @ 0x%p.\n", memSize, numByteToBeWritten, memAddress);
Richard Burton00553462021-11-10 16:27:14 +0000308 numByteToBeWritten = memSize;
309 }
310
311 std::memcpy(memAddress, audioFrame.data(), numByteToBeWritten);
312 info("Copied %zu bytes to %p\n", numByteToBeWritten, memAddress);
313
314 return numByteToBeWritten;
315 }
316
317 size_t DumpOutputTensorsToMemory(Model& model, uint8_t* memAddress, const size_t memSize)
318 {
319 const size_t numOutputs = model.GetNumOutputs();
320 size_t numBytesWritten = 0;
321 uint8_t* ptr = memAddress;
322
323 /* Iterate over all output tensors. */
324 for (size_t i = 0; i < numOutputs; ++i) {
325 const TfLiteTensor* tensor = model.GetOutputTensor(i);
326 const auto* tData = tflite::GetTensorData<uint8_t>(tensor);
327#if VERIFY_TEST_OUTPUT
328 arm::app::DumpTensor(tensor);
329#endif /* VERIFY_TEST_OUTPUT */
330 /* Ensure that we don't overflow the allowed limit. */
331 if (numBytesWritten + tensor->bytes <= memSize) {
332 if (tensor->bytes > 0) {
333 std::memcpy(ptr, tData, tensor->bytes);
334
335 info("Copied %zu bytes for tensor %zu to 0x%p\n",
336 tensor->bytes, i, ptr);
337
338 numBytesWritten += tensor->bytes;
339 ptr += tensor->bytes;
340 }
341 } else {
342 printf_err("Error writing tensor %zu to memory @ 0x%p\n",
343 i, memAddress);
344 break;
345 }
346 }
347
348 info("%zu bytes written to memory @ 0x%p\n", numBytesWritten, memAddress);
349
350 return numBytesWritten;
351 }
352
353 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
354 {
355 auto curClipIdx = ctx.Get<uint32_t>("clipIndex");
356 if (curClipIdx + 1 >= NUMBER_OF_FILES) {
357 ctx.Set<uint32_t>("clipIndex", 0);
358 return;
359 }
360 ++curClipIdx;
361 ctx.Set<uint32_t>("clipIndex", curClipIdx);
362 }
363
364 void QuantizeAndPopulateInput(rnn::vec1D32F& inputFeatures,
365 const float quantScale, const int quantOffset, TfLiteTensor* inputTensor)
366 {
367 const float minVal = std::numeric_limits<int8_t>::min();
368 const float maxVal = std::numeric_limits<int8_t>::max();
369
370 auto* inputTensorData = tflite::GetTensorData<int8_t>(inputTensor);
371
372 for (size_t i=0; i < inputFeatures.size(); ++i) {
373 float quantValue = ((inputFeatures[i] / quantScale) + quantOffset);
374 inputTensorData[i] = static_cast<int8_t>(std::min<float>(std::max<float>(quantValue, minVal), maxVal));
375 }
376 }
377
378
379} /* namespace app */
380} /* namespace arm */