blob: c2d2ea4ededc8227e4187a703f884306e0659431 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "InputFiles.hpp"
20#include "Classifier.hpp"
Kshitij Sisodia76a15802021-12-24 11:05:11 +000021#include "MicroNetKwsModel.hpp"
alexander3c798932021-03-26 21:42:19 +000022#include "hal.h"
Kshitij Sisodia76a15802021-12-24 11:05:11 +000023#include "MicroNetKwsMfcc.hpp"
alexander3c798932021-03-26 21:42:19 +000024#include "AudioUtils.hpp"
25#include "UseCaseCommonUtils.hpp"
26#include "KwsResult.hpp"
alexander31ae9f02022-02-10 16:15:54 +000027#include "log_macros.h"
alexander3c798932021-03-26 21:42:19 +000028
29#include <vector>
30#include <functional>
31
32using KwsClassifier = arm::app::Classifier;
33
34namespace arm {
35namespace app {
36
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010037
    /**
     * @brief  Presents KWS inference results: draws one summary line per
     *         inference on the LCD via the platform's data presentation
     *         object and logs the full per-result details via info().
     * @param[in] platform  Reference to the hal platform object (provides
     *                      the data presentation interface).
     * @param[in] results   Vector of KWS classification results to be displayed.
     * @return  true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       const std::vector<arm::app::kws::KwsResult>& results);
alexander3c798932021-03-26 21:42:19 +000047
    /**
     * @brief Returns a function to perform feature calculation and populates input tensor data with
     * MFCC data.
     *
     * Input tensor data type check is performed to choose correct MFCC feature data type.
     * If tensor has an integer data type then original features are quantised.
     *
     * Warning: MFCC calculator provided as input must have the same life scope as returned function.
     *
     * @param[in]     mfcc        MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in]     cacheSize   Size of the feature vectors cache (number of feature vectors).
     * @return Function to be called providing audio sample and sliding window index;
     *         an empty std::function if the tensor type is not supported.
     */
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize);
66
    /**
     * @brief  Audio inference handler. Runs keyword spotting over one baked-in
     *         audio clip (or loops over all clips when runAll is set), sliding
     *         an inference window across each clip, computing MFCC features
     *         into the model's input tensor, running inference, and presenting
     *         the classification results.
     *
     *         Reads from the application context: "platform", "profiler",
     *         "model", "clipIndex", "frameLength", "frameStride",
     *         "scoreThreshold", "classifier", "labels".
     *         Writes "results" back into the context and advances "clipIndex".
     *
     * @param[in,out] ctx       Application context holding model, HAL and parameters.
     * @param[in]     clipIndex Index of the clip to process; ignored if >= NUMBER_OF_FILES.
     * @param[in]     runAll    If true, process every clip, wrapping around until
     *                          the starting index is reached again.
     * @return true on success, false on any initialisation or inference failure.
     */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        auto& profiler = ctx.Get<Profiler&>("profiler");

        /* LCD coordinates for the "Running inference..." banner. */
        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;
        /* The input tensor must have at least max(rowsIdx, colsIdx)+... dims
         * so that rows/cols can be read from its shape below. */
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::MicroNetKwsModel::ms_inputRowsIdx > arm::app::MicroNetKwsModel::ms_inputColsIdx)?
             arm::app::MicroNetKwsModel::ms_inputRowsIdx : arm::app::MicroNetKwsModel::ms_inputColsIdx);

        auto& model = ctx.Get<Model&>("model");

        /* If the request has a valid size, set the audio index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
                return false;
            }
        }
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        const auto frameLength = ctx.Get<int>("frameLength");
        const auto frameStride = ctx.Get<int>("frameStride");
        const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
        /* Remember where we started so runAll can stop after a full wrap-around. */
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
        TfLiteTensor* inputTensor = model.GetInputTensor(0);

        if (!inputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return false;
        } else if (inputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return false;
        }

        /* Read the number of MFCC rows/columns from the model's input shape. */
        TfLiteIntArray* inputShape = model.GetInputShape(0);
        const uint32_t kNumCols = inputShape->data[arm::app::MicroNetKwsModel::ms_inputColsIdx];
        const uint32_t kNumRows = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];

        audio::MicroNetKwsMFCC mfcc = audio::MicroNetKwsMFCC(kNumCols, frameLength);
        mfcc.Init();

        /* Deduce the data length required for 1 inference from the network parameters. */
        auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride);
        auto mfccWindowSize = frameLength;
        auto mfccWindowStride = frameStride;

        /* We choose to move by half the window size => for a 1 second window size
         * there is an overlap of 0.5 seconds. */
        auto audioDataStride = audioDataWindowSize / 2;

        /* To have the previously calculated features re-usable, stride must be multiple
         * of MFCC features window stride. */
        if (0 != audioDataStride % mfccWindowStride) {

            /* Reduce the stride. */
            audioDataStride -= audioDataStride % mfccWindowStride;
        }

        auto nMfccVectorsInAudioStride = audioDataStride/mfccWindowStride;

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float secondsPerSample = 1.0/audio::MicroNetKwsMFCC::ms_defaultSamplingFreq;

        do {
            platform.data_psn->clear(COLOR_BLACK);

            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Creating a mfcc features sliding window for the data required for 1 inference. */
            auto audioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    audioDataWindowSize, mfccWindowSize,
                    mfccWindowStride);

            /* Creating a sliding window through the whole audio clip. */
            auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    get_audio_array_size(currentIndex),
                    audioDataWindowSize, audioDataStride);

            /* Calculate number of the feature vectors in the window overlap region.
             * These feature vectors will be reused.*/
            auto numberOfReusedFeatureVectors = audioMFCCWindowSlider.TotalStrides() + 1
                                                - nMfccVectorsInAudioStride;

            /* Construct feature calculation function. */
            auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor,
                                                        numberOfReusedFeatureVectors);

            /* An empty std::function here means the tensor type was unsupported. */
            if (!mfccFeatureCalc){
                return false;
            }

            /* Declare a container for results. */
            std::vector<arm::app::kws::KwsResult> results;

            /* Display message on the LCD - inference running. */
            std::string str_inf{"Running inference... "};
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                 get_filename(currentIndex));

            /* Start sliding through audio clip. */
            while (audioDataSlider.HasNext()) {
                const int16_t *inferenceWindow = audioDataSlider.Next();

                /* We moved to the next window - set the features sliding to the new address. */
                audioMFCCWindowSlider.Reset(inferenceWindow);

                /* The first window does not have cache ready. */
                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

                /* Start calculating features inside one audio sliding window. */
                while (audioMFCCWindowSlider.HasNext()) {
                    const int16_t *mfccWindow = audioMFCCWindowSlider.Next();
                    std::vector<int16_t> mfccAudioData = std::vector<int16_t>(mfccWindow,
                                                            mfccWindow + mfccWindowSize);
                    /* Compute features for this window and write them to input tensor. */
                    mfccFeatureCalc(mfccAudioData,
                                    audioMFCCWindowSlider.Index(),
                                    useCache,
                                    nMfccVectorsInAudioStride);
                }

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                     audioDataSlider.TotalStrides() + 1);

                /* Run inference over this audio clip sliding window. */
                if (!RunInference(model, profiler)) {
                    return false;
                }

                /* Classify the output tensor; keep only the top 1 result,
                 * with softmax applied (last argument). */
                std::vector<ClassificationResult> classificationResult;
                auto& classifier = ctx.Get<KwsClassifier&>("classifier");
                classifier.GetClassificationResults(outputTensor, classificationResult,
                        ctx.Get<std::vector<std::string>&>("labels"), 1, true);

                /* Record the result with the timestamp of this window's start. */
                results.emplace_back(kws::KwsResult(classificationResult,
                    audioDataSlider.Index() * secondsPerSample * audioDataStride,
                    audioDataSlider.Index(), scoreThreshold));

#if VERIFY_TEST_OUTPUT
                arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
            } /* while (audioDataSlider.HasNext()) */

            /* Erase the "Running inference..." banner by over-writing with spaces. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

            ctx.Set<std::vector<arm::app::kws::KwsResult>>("results", results);

            if (!PresentInferenceResult(platform, results)) {
                return false;
            }

            profiler.PrintProfilingResult();

            IncrementAppCtxIfmIdx(ctx,"clipIndex");

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }
243
alexanderc350cdc2021-04-29 20:36:09 +0100244 static bool PresentInferenceResult(hal_platform& platform,
245 const std::vector<arm::app::kws::KwsResult>& results)
alexander3c798932021-03-26 21:42:19 +0000246 {
247 constexpr uint32_t dataPsnTxtStartX1 = 20;
248 constexpr uint32_t dataPsnTxtStartY1 = 30;
249 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */
250
251 platform.data_psn->set_text_color(COLOR_GREEN);
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100252 info("Final results:\n");
253 info("Total number of inferences: %zu\n", results.size());
alexander3c798932021-03-26 21:42:19 +0000254
255 /* Display each result */
256 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
257
258 for (uint32_t i = 0; i < results.size(); ++i) {
259
260 std::string topKeyword{"<none>"};
261 float score = 0.f;
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100262 if (!results[i].m_resultVec.empty()) {
alexander3c798932021-03-26 21:42:19 +0000263 topKeyword = results[i].m_resultVec[0].m_label;
264 score = results[i].m_resultVec[0].m_normalisedVal;
265 }
266
267 std::string resultStr =
268 std::string{"@"} + std::to_string(results[i].m_timeStamp) +
269 std::string{"s: "} + topKeyword + std::string{" ("} +
270 std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};
271
272 platform.data_psn->present_data_text(
273 resultStr.c_str(), resultStr.size(),
274 dataPsnTxtStartX1, rowIdx1, false);
275 rowIdx1 += dataPsnTxtYIncr;
276
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100277 if (results[i].m_resultVec.empty()) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100278 info("For timestamp: %f (inference #: %" PRIu32
279 "); label: %s; threshold: %f\n",
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100280 results[i].m_timeStamp, results[i].m_inferenceNumber,
281 topKeyword.c_str(),
282 results[i].m_threshold);
283 } else {
284 for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100285 info("For timestamp: %f (inference #: %" PRIu32
286 "); label: %s, score: %f; threshold: %f\n",
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100287 results[i].m_timeStamp,
288 results[i].m_inferenceNumber,
289 results[i].m_resultVec[j].m_label.c_str(),
290 results[i].m_resultVec[j].m_normalisedVal,
291 results[i].m_threshold);
292 }
alexander3c798932021-03-26 21:42:19 +0000293 }
294 }
295
296 return true;
297 }
298
    /**
     * @brief Generic feature calculator factory.
     *
     * Returns lambda function to compute features using features cache.
     * Real features math is done by a lambda function provided as a parameter.
     * Features are written to input tensor memory.
     *
     * @tparam T              Feature vector element type (matches the tensor type).
     * @param[in] inputTensor Model input tensor pointer.
     * @param[in] cacheSize   Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param[in] compute     Features calculator function.
     * @return Lambda function to compute features for one MFCC window and write
     *         them at the given feature-vector index of the input tensor.
     */
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>& )> compute)
    {
        /* Feature cache to be captured by lambda function.
         * NOTE(review): this is a function-local static, so it is constructed
         * exactly once per template instantiation T; a later call with a
         * different cacheSize silently reuses the original cache size, and all
         * returned lambdas of the same T share this cache. Confirm single-use
         * semantics are intended. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T *tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from cache if cache is ready and sliding windows overlap.
             * Overlap is in the beginning of sliding window with a size of a feature cache. */
            if (useCache && index < featureCache.size()) {
                /* Cache slot is consumed (moved-from) once used. */
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
            /* Copy the feature vector into its slot of the input tensor. */
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing cache as soon iteration goes out of the windows overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }
345
    /* Explicit instantiations of FeatureCalc for every tensor element type
     * handled by GetFeatureCalculator below: int8/uint8/int16 (quantised)
     * and float. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>& )> compute);

    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>& )> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000365
366
367 static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
Kshitij Sisodia76a15802021-12-24 11:05:11 +0000368 GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
alexander3c798932021-03-26 21:42:19 +0000369 {
370 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
371
372 TfLiteQuantization quant = inputTensor->quantization;
373
374 if (kTfLiteAffineQuantization == quant.type) {
375
376 auto *quantParams = (TfLiteAffineQuantization *) quant.params;
377 const float quantScale = quantParams->scale->data[0];
378 const int quantOffset = quantParams->zero_point->data[0];
379
380 switch (inputTensor->type) {
381 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100382 mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
383 cacheSize,
384 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
385 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
386 quantScale,
387 quantOffset);
388 }
alexander3c798932021-03-26 21:42:19 +0000389 );
390 break;
391 }
392 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100393 mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
394 cacheSize,
alexander3c798932021-03-26 21:42:19 +0000395 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
396 return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
397 quantScale,
398 quantOffset);
399 }
400 );
401 break;
402 }
403 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100404 mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
405 cacheSize,
406 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
407 return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
408 quantScale,
409 quantOffset);
410 }
alexander3c798932021-03-26 21:42:19 +0000411 );
412 break;
413 }
414 default:
415 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
416 }
417
418
419 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100420 mfccFeatureCalc = mfccFeatureCalc = FeatureCalc<float>(inputTensor,
421 cacheSize,
422 [&mfcc](std::vector<int16_t>& audioDataWindow) {
423 return mfcc.MfccCompute(audioDataWindow);
424 });
alexander3c798932021-03-26 21:42:19 +0000425 }
426 return mfccFeatureCalc;
427 }
428
429} /* namespace app */
430} /* namespace arm */