/*
 * Copyright (c) 2021 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "InputFiles.hpp"
#include "Classifier.hpp"
#include "DsCnnModel.hpp"
#include "hal.h"
#include "DsCnnMfcc.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "KwsResult.hpp"

#include <cinttypes>    /* For the PRIu32 format macro used in logging. */
#include <cstring>      /* For std::memcpy used by the feature calculator. */
#include <functional>
#include <vector>

using KwsClassifier = arm::app::Classifier;

namespace arm {
namespace app {

    /**
     * @brief Helper function to increment current audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief Helper function to set the audio clip index.
     * @param[in,out] ctx Reference to the application context object.
     * @param[in] idx Value to be set.
     * @return true if the index was set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief Presents inference results using the data presentation
     *        object.
     * @param[in] platform Reference to the hal platform object.
     * @param[in] results Vector of classification results to be displayed.
     * @return true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform,
                                       const std::vector<arm::app::kws::KwsResult>& results);

    /**
     * @brief Returns a function that performs feature calculation and populates the
     *        input tensor with MFCC data.
     *
     * The input tensor's data type is checked to choose the correct MFCC feature data type.
     * If the tensor has an integer data type, the computed features are quantised.
     *
     * Warning: the MFCC calculator passed in must outlive the returned function.
     *
     * @param[in]     mfcc        MFCC feature calculator.
     * @param[in,out] inputTensor Input tensor pointer to store calculated features.
     * @param[in]     cacheSize   Size of the feature vectors cache (number of feature vectors).
     * @return Function to be called with an audio sample window and sliding window index.
     */
    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize);
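    /* Illustrative use of the returned callable (this mirrors ClassifyAudioHandler below):
     *
     *   auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor, cacheSize);
     *   mfccFeatureCalc(audioFrame, frameIdx, useCache, nMfccVectorsInAudioStride);
     *
     * Each call computes (or, when useCache is true and frameIdx falls in the overlap
     * region, re-uses) one MFCC feature vector and writes it into the input tensor slot
     * selected by frameIdx. */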

    /* Audio inference handler. */
    bool ClassifyAudioHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
    {
        auto& platform = ctx.Get<hal_platform&>("platform");
        auto& profiler = ctx.Get<Profiler&>("profiler");

        constexpr uint32_t dataPsnTxtInfStartX = 20;
        constexpr uint32_t dataPsnTxtInfStartY = 40;
        constexpr int minTensorDims = static_cast<int>(
            (arm::app::DsCnnModel::ms_inputRowsIdx > arm::app::DsCnnModel::ms_inputColsIdx)?
             arm::app::DsCnnModel::ms_inputRowsIdx : arm::app::DsCnnModel::ms_inputColsIdx);

        platform.data_psn->clear(COLOR_BLACK);

        auto& model = ctx.Get<Model&>("model");

        /* If the request has a valid index, set the audio clip index. */
        if (clipIndex < NUMBER_OF_FILES) {
            if (!SetAppCtxClipIdx(ctx, clipIndex)) {
                return false;
            }
        }
        if (!model.IsInited()) {
            printf_err("Model is not initialised! Terminating processing.\n");
            return false;
        }

        const auto frameLength = ctx.Get<int>("frameLength");
        const auto frameStride = ctx.Get<int>("frameStride");
        const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
        auto startClipIdx = ctx.Get<uint32_t>("clipIndex");

        TfLiteTensor* outputTensor = model.GetOutputTensor(0);
        TfLiteTensor* inputTensor = model.GetInputTensor(0);

        if (!inputTensor->dims) {
            printf_err("Invalid input tensor dims\n");
            return false;
        } else if (inputTensor->dims->size < minTensorDims) {
            printf_err("Input tensor dimension should be >= %d\n", minTensorDims);
            return false;
        }

        TfLiteIntArray* inputShape = model.GetInputShape(0);
        const uint32_t kNumCols = inputShape->data[arm::app::DsCnnModel::ms_inputColsIdx];
        const uint32_t kNumRows = inputShape->data[arm::app::DsCnnModel::ms_inputRowsIdx];

        audio::DsCnnMFCC mfcc = audio::DsCnnMFCC(kNumCols, frameLength);
        mfcc.Init();

        /* Deduce the data length required for 1 inference from the network parameters. */
        auto audioDataWindowSize = kNumRows * frameStride + (frameLength - frameStride);
        auto mfccWindowSize = frameLength;
        auto mfccWindowStride = frameStride;

        /* We choose to move by half the window size => for a 1 second window size
         * there is an overlap of 0.5 seconds. */
        auto audioDataStride = audioDataWindowSize / 2;

        /* To have the previously calculated features re-usable, the stride must be a
         * multiple of the MFCC features window stride. */
        if (0 != audioDataStride % mfccWindowStride) {

            /* Reduce the stride. */
            audioDataStride -= audioDataStride % mfccWindowStride;
        }

        auto nMfccVectorsInAudioStride = audioDataStride / mfccWindowStride;
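
        /* Example figures, assuming the default DS-CNN KWS setup (16 kHz audio, 640 sample
         * frame length, 320 sample frame stride, 49 input rows): audioDataWindowSize is
         * 49 * 320 + 320 = 16000 samples (1 second), audioDataStride is 8000 samples
         * (0.5 seconds) and nMfccVectorsInAudioStride is 8000 / 320 = 25. */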

        /* We expect to be sampling 1 second worth of data at a time.
         * NOTE: This is only used for time stamp calculation. */
        const float secondsPerSample = 1.0 / audio::DsCnnMFCC::ms_defaultSamplingFreq;

        do {
            auto currentIndex = ctx.Get<uint32_t>("clipIndex");

            /* Creating a sliding window of MFCC features over the data required for 1 inference. */
            auto audioMFCCWindowSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    audioDataWindowSize, mfccWindowSize,
                    mfccWindowStride);

            /* Creating a sliding window through the whole audio clip. */
            auto audioDataSlider = audio::SlidingWindow<const int16_t>(
                    get_audio_array(currentIndex),
                    get_audio_array_size(currentIndex),
                    audioDataWindowSize, audioDataStride);

            /* Calculate the number of feature vectors in the window overlap region.
             * These feature vectors will be reused. */
            auto numberOfReusedFeatureVectors = audioMFCCWindowSlider.TotalStrides() + 1
                                                - nMfccVectorsInAudioStride;
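
            /* For the example figures above, each 1 second window holds 49 feature vectors,
             * 25 of which are new per stride, so up to 24 vectors can be reused from the
             * previous window. */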

            /* Construct feature calculation function. */
            auto mfccFeatureCalc = GetFeatureCalculator(mfcc, inputTensor,
                                                        numberOfReusedFeatureVectors);

            if (!mfccFeatureCalc) {
                return false;
            }

            /* Declare a container for results. */
            std::vector<arm::app::kws::KwsResult> results;

            /* Display message on the LCD - inference running. */
            std::string str_inf{"Running inference... "};
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
            info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex,
                 get_filename(currentIndex));

            /* Start sliding through the audio clip. */
            while (audioDataSlider.HasNext()) {
                const int16_t* inferenceWindow = audioDataSlider.Next();

                /* We moved to the next window - set the features sliding window to the new address. */
                audioMFCCWindowSlider.Reset(inferenceWindow);

                /* The first window does not have the cache ready. */
                bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;

                /* Start calculating features inside one audio sliding window. */
                while (audioMFCCWindowSlider.HasNext()) {
                    const int16_t* mfccWindow = audioMFCCWindowSlider.Next();
                    std::vector<int16_t> mfccAudioData = std::vector<int16_t>(mfccWindow,
                                                                              mfccWindow + mfccWindowSize);
                    /* Compute features for this window and write them to the input tensor. */
                    mfccFeatureCalc(mfccAudioData,
                                    audioMFCCWindowSlider.Index(),
                                    useCache,
                                    nMfccVectorsInAudioStride);
                }

                info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
                     audioDataSlider.TotalStrides() + 1);

                /* Run inference over this audio clip sliding window. */
                if (!RunInference(model, profiler)) {
                    return false;
                }

                std::vector<ClassificationResult> classificationResult;
                auto& classifier = ctx.Get<KwsClassifier&>("classifier");
                classifier.GetClassificationResults(outputTensor, classificationResult,
                                                    ctx.Get<std::vector<std::string>&>("labels"), 1);

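                /* The result timestamp is the start of this window within the clip:
                 * window index * stride in samples * seconds per sample. */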
                results.emplace_back(kws::KwsResult(classificationResult,
                    audioDataSlider.Index() * secondsPerSample * audioDataStride,
                    audioDataSlider.Index(), scoreThreshold));

#if VERIFY_TEST_OUTPUT
                arm::app::DumpTensor(outputTensor);
#endif /* VERIFY_TEST_OUTPUT */
            } /* while (audioDataSlider.HasNext()) */

            /* Erase. */
            str_inf = std::string(str_inf.size(), ' ');
            platform.data_psn->present_data_text(
                                str_inf.c_str(), str_inf.size(),
                                dataPsnTxtInfStartX, dataPsnTxtInfStartY, false);

            ctx.Set<std::vector<arm::app::kws::KwsResult>>("results", results);

            if (!PresentInferenceResult(platform, results)) {
                return false;
            }

            profiler.PrintProfilingResult();

            IncrementAppCtxClipIdx(ctx);

        } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);

        return true;
    }

    static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
    {
        auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");

        if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
            ctx.Set<uint32_t>("clipIndex", 0);
            return;
        }
        ++curAudioIdx;
        ctx.Set<uint32_t>("clipIndex", curAudioIdx);
    }

    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
    {
        if (idx >= NUMBER_OF_FILES) {
            printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
                       idx, NUMBER_OF_FILES);
            return false;
        }
        ctx.Set<uint32_t>("clipIndex", idx);
        return true;
    }

    static bool PresentInferenceResult(hal_platform& platform,
                                       const std::vector<arm::app::kws::KwsResult>& results)
    {
        constexpr uint32_t dataPsnTxtStartX1 = 20;
        constexpr uint32_t dataPsnTxtStartY1 = 30;
        constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment. */

        platform.data_psn->set_text_color(COLOR_GREEN);
        info("Final results:\n");
        info("Total number of inferences: %zu\n", results.size());

        /* Display each result. */
        uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;

        for (uint32_t i = 0; i < results.size(); ++i) {

            std::string topKeyword{"<none>"};
            float score = 0.f;

            if (!results[i].m_resultVec.empty()) {
                topKeyword = results[i].m_resultVec[0].m_label;
                score = results[i].m_resultVec[0].m_normalisedVal;
            }

            std::string resultStr =
                std::string{"@"} + std::to_string(results[i].m_timeStamp) +
                std::string{"s: "} + topKeyword + std::string{" ("} +
                std::to_string(static_cast<int>(score * 100)) + std::string{"%)"};

            platform.data_psn->present_data_text(
                    resultStr.c_str(), resultStr.size(),
                    dataPsnTxtStartX1, rowIdx1, false);
            rowIdx1 += dataPsnTxtYIncr;

            if (results[i].m_resultVec.empty()) {
                info("For timestamp: %f (inference #: %" PRIu32
                     "); label: %s; threshold: %f\n",
                     results[i].m_timeStamp, results[i].m_inferenceNumber,
                     topKeyword.c_str(),
                     results[i].m_threshold);
            } else {
                for (uint32_t j = 0; j < results[i].m_resultVec.size(); ++j) {
                    info("For timestamp: %f (inference #: %" PRIu32
                         "); label: %s, score: %f; threshold: %f\n",
                         results[i].m_timeStamp,
                         results[i].m_inferenceNumber,
                         results[i].m_resultVec[j].m_label.c_str(),
                         results[i].m_resultVec[j].m_normalisedVal,
                         results[i].m_threshold);
                }
            }
        }

        return true;
    }

    /**
     * @brief Generic feature calculator factory.
     *
     * Returns a lambda function that computes features using a feature cache.
     * The actual feature maths is done by the compute function provided as a parameter.
     * Features are written to the input tensor memory.
     *
     * @tparam T Feature vector type.
     * @param[in] inputTensor Model input tensor pointer.
     * @param[in] cacheSize   Number of feature vectors to cache. Defined by the sliding window overlap.
     * @param[in] compute     Feature calculator function.
     * @return Lambda function to compute features.
     */
    template<class T>
    std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
                std::function<std::vector<T> (std::vector<int16_t>&)> compute)
    {
        /* Feature cache to be captured by the lambda function. */
        static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
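
        /* Note: the cache is a function-local static, so it is created once per template
         * instantiation (per feature type T) on the first call and keeps the cacheSize
         * passed at that point; later calls reuse the same storage. */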

        return [=](std::vector<int16_t>& audioDataWindow,
                   size_t index,
                   bool useCache,
                   size_t featuresOverlapIndex)
        {
            T* tensorData = tflite::GetTensorData<T>(inputTensor);
            std::vector<T> features;

            /* Reuse features from the cache if the cache is ready and the sliding windows overlap.
             * The overlap is at the beginning of the sliding window and is as long as the feature cache. */
            if (useCache && index < featureCache.size()) {
                features = std::move(featureCache[index]);
            } else {
                features = std::move(compute(audioDataWindow));
            }
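
            /* Each frame's feature vector occupies a contiguous slice of the input tensor;
             * frame `index` starts at element index * (features per frame). */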
            auto size = features.size();
            auto sizeBytes = sizeof(T) * size;
            std::memcpy(tensorData + (index * size), features.data(), sizeBytes);

            /* Start renewing the cache as soon as the iteration goes out of the window overlap. */
            if (index >= featuresOverlapIndex) {
                featureCache[index - featuresOverlapIndex] = std::move(features);
            }
        };
    }

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                        size_t cacheSize,
                        std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);

    template std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
    FeatureCalc<float>(TfLiteTensor* inputTensor,
                       size_t cacheSize,
                       std::function<std::vector<float> (std::vector<int16_t>&)> compute);

    static std::function<void (std::vector<int16_t>&, int, bool, size_t)>
    GetFeatureCalculator(audio::DsCnnMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
    {
        std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;

        TfLiteQuantization quant = inputTensor->quantization;

        if (kTfLiteAffineQuantization == quant.type) {

            auto* quantParams = (TfLiteAffineQuantization*) quant.params;
            const float quantScale = quantParams->scale->data[0];
            const int quantOffset = quantParams->zero_point->data[0];
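
            /* The quantised MFCC path maps each real feature value x to roughly
             * x / quantScale + quantOffset, clamped to the range of the target integer type. */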
            switch (inputTensor->type) {
                case kTfLiteInt8: {
                    mfccFeatureCalc = FeatureCalc<int8_t>(inputTensor,
                                                          cacheSize,
                                                          [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                              return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
                                                                                                   quantScale,
                                                                                                   quantOffset);
                                                          }
                    );
                    break;
                }
                case kTfLiteUInt8: {
                    mfccFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<uint8_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                case kTfLiteInt16: {
                    mfccFeatureCalc = FeatureCalc<int16_t>(inputTensor,
                                                           cacheSize,
                                                           [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
                                                               return mfcc.MfccComputeQuant<int16_t>(audioDataWindow,
                                                                                                     quantScale,
                                                                                                     quantOffset);
                                                           }
                    );
                    break;
                }
                default:
                    printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
            }

        } else {
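            /* No affine quantisation info: the tensor is treated as floating point and the
             * MFCC features are used as computed. */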
            mfccFeatureCalc = FeatureCalc<float>(inputTensor,
                                                 cacheSize,
                                                 [&mfcc](std::vector<int16_t>& audioDataWindow) {
                                                     return mfcc.MfccCompute(audioDataWindow);
                                                 });
        }
        return mfccFeatureCalc;
    }

} /* namespace app */
} /* namespace arm */