blob: ec35156ea6d81599f76985bd5a1fadaf9a7a0c3f [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
#include "UseCaseHandler.hpp"

#include <cinttypes>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "AdModel.hpp"
#include "InputFiles.hpp"
#include "Classifier.hpp"
#include "hal.h"
#include "AdMelSpectrogram.hpp"
#include "AudioUtils.hpp"
#include "UseCaseCommonUtils.hpp"
#include "AdPostProcessing.hpp"
27
28namespace arm {
29namespace app {
30
31 /**
32 * @brief Helper function to increment current audio clip index
Isabella Gottardi56ee6202021-05-12 08:27:15 +010033 * @param[in,out] ctx pointer to the application context object
alexander3c798932021-03-26 21:42:19 +000034 **/
alexanderc350cdc2021-04-29 20:36:09 +010035 static void IncrementAppCtxClipIdx(ApplicationContext& ctx);
alexander3c798932021-03-26 21:42:19 +000036
37 /**
38 * @brief Helper function to set the audio clip index
Isabella Gottardi56ee6202021-05-12 08:27:15 +010039 * @param[in,out] ctx pointer to the application context object
alexander3c798932021-03-26 21:42:19 +000040 * @param[in] idx value to be set
41 * @return true if index is set, false otherwise
42 **/
alexanderc350cdc2021-04-29 20:36:09 +010043 static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);
alexander3c798932021-03-26 21:42:19 +000044
45 /**
46 * @brief Presents inference results using the data presentation
47 * object.
48 * @param[in] platform reference to the hal platform object
49 * @param[in] result average sum of classification results
Isabella Gottardi56ee6202021-05-12 08:27:15 +010050 * @param[in] threshold if larger than this value we have an anomaly
alexander3c798932021-03-26 21:42:19 +000051 * @return true if successful, false otherwise
52 **/
alexanderc350cdc2021-04-29 20:36:09 +010053 static bool PresentInferenceResult(hal_platform& platform, float result, float threshold);
alexander3c798932021-03-26 21:42:19 +000054
55 /**
56 * @brief Returns a function to perform feature calculation and populates input tensor data with
57 * MelSpe data.
58 *
59 * Input tensor data type check is performed to choose correct MFCC feature data type.
60 * If tensor has an integer data type then original features are quantised.
61 *
62 * Warning: mfcc calculator provided as input must have the same life scope as returned function.
63 *
Isabella Gottardi56ee6202021-05-12 08:27:15 +010064 * @param[in] melSpec MFCC feature calculator.
65 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
66 * @param[in] cacheSize Size of the feture vectors cache (number of feature vectors).
67 * @param[in] trainingMean Training mean.
alexander3c798932021-03-26 21:42:19 +000068 * @return function function to be called providing audio sample and sliding window index.
69 */
70 static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
71 GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
72 TfLiteTensor* inputTensor,
73 size_t cacheSize,
74 float trainingMean);
75
76 /* Vibration classification handler */
77 bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
78 {
79 auto& platform = ctx.Get<hal_platform&>("platform");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010080 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +000081
82 constexpr uint32_t dataPsnTxtInfStartX = 20;
83 constexpr uint32_t dataPsnTxtInfStartY = 40;
84
85 platform.data_psn->clear(COLOR_BLACK);
86
87 auto& model = ctx.Get<Model&>("model");
88
89 /* If the request has a valid size, set the audio index */
90 if (clipIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +010091 if (!SetAppCtxClipIdx(ctx, clipIndex)) {
alexander3c798932021-03-26 21:42:19 +000092 return false;
93 }
94 }
95 if (!model.IsInited()) {
96 printf_err("Model is not initialised! Terminating processing.\n");
97 return false;
98 }
99
100 const auto frameLength = ctx.Get<int>("frameLength");
101 const auto frameStride = ctx.Get<int>("frameStride");
102 const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100103 const auto trainingMean = ctx.Get<float>("trainingMean");
alexander3c798932021-03-26 21:42:19 +0000104 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
105
106 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
107 TfLiteTensor* inputTensor = model.GetInputTensor(0);
108
109 if (!inputTensor->dims) {
110 printf_err("Invalid input tensor dims\n");
111 return false;
112 }
113
114 TfLiteIntArray* inputShape = model.GetInputShape(0);
115 const uint32_t kNumRows = inputShape->data[1];
116 const uint32_t kNumCols = inputShape->data[2];
117
118 audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength);
119 melSpec.Init();
120
121 /* Deduce the data length required for 1 inference from the network parameters. */
122 const uint8_t inputResizeScale = 2;
123 const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength;
124
125 /* We are choosing to move by 20 frames across the audio for each inference. */
126 const uint8_t nMelSpecVectorsInAudioStride = 20;
127
128 auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride;
129
130 do {
131 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
132
133 /* Get the output index to look at based on id in the filename. */
134 int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex));
135 if (machineOutputIndex == -1) {
136 return false;
137 }
138
139 /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
140 * "resizing" done here by multiplying stride by resize scale. */
141 auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>(
142 get_audio_array(currentIndex),
143 audioDataWindowSize, frameLength,
144 frameStride * inputResizeScale);
145
146 /* Creating a sliding window through the whole audio clip. */
147 auto audioDataSlider = audio::SlidingWindow<const int16_t>(
148 get_audio_array(currentIndex),
149 get_audio_array_size(currentIndex),
150 audioDataWindowSize, audioDataStride);
151
152 /* Calculate number of the feature vectors in the window overlap region taking into account resizing.
153 * These feature vectors will be reused.*/
154 auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale);
155
156 /* Construct feature calculation function. */
157 auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor,
158 numberOfReusedFeatureVectors, trainingMean);
159 if (!melSpecFeatureCalc){
160 return false;
161 }
162
163 /* Result is an averaged sum over inferences. */
164 float result = 0;
165
166 /* Display message on the LCD - inference running. */
167 std::string str_inf{"Running inference... "};
168 platform.data_psn->present_data_text(
169 str_inf.c_str(), str_inf.size(),
170 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100171 info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000172
173 /* Start sliding through audio clip. */
174 while (audioDataSlider.HasNext()) {
175 const int16_t *inferenceWindow = audioDataSlider.Next();
176
177 /* We moved to the next window - set the features sliding to the new address. */
178 audioMelSpecWindowSlider.Reset(inferenceWindow);
179
180 /* The first window does not have cache ready. */
181 bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
182
183 /* Start calculating features inside one audio sliding window. */
184 while (audioMelSpecWindowSlider.HasNext()) {
185 const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next();
186 std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow,
187 melSpecWindow + frameLength);
188
189 /* Compute features for this window and write them to input tensor. */
190 melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(),
191 useCache, nMelSpecVectorsInAudioStride, inputResizeScale);
192 }
193
194 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
195 audioDataSlider.TotalStrides() + 1);
196
197 /* Run inference over this audio clip sliding window */
alexander27b62d92021-05-04 20:46:08 +0100198 if (!RunInference(model, profiler)) {
199 return false;
200 }
alexander3c798932021-03-26 21:42:19 +0000201
202 /* Use the negative softmax score of the corresponding index as the outlier score */
203 std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
204 Softmax(dequantOutput);
205 result += -dequantOutput[machineOutputIndex];
206
207#if VERIFY_TEST_OUTPUT
208 arm::app::DumpTensor(outputTensor);
209#endif /* VERIFY_TEST_OUTPUT */
210 } /* while (audioDataSlider.HasNext()) */
211
212 /* Use average over whole clip as final score. */
213 result /= (audioDataSlider.TotalStrides() + 1);
214
215 /* Erase. */
216 str_inf = std::string(str_inf.size(), ' ');
217 platform.data_psn->present_data_text(
218 str_inf.c_str(), str_inf.size(),
219 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
220
221 ctx.Set<float>("result", result);
alexanderc350cdc2021-04-29 20:36:09 +0100222 if (!PresentInferenceResult(platform, result, scoreThreshold)) {
alexander3c798932021-03-26 21:42:19 +0000223 return false;
224 }
225
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100226 profiler.PrintProfilingResult();
227
alexanderc350cdc2021-04-29 20:36:09 +0100228 IncrementAppCtxClipIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000229
230 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
231
232 return true;
233 }
234
alexanderc350cdc2021-04-29 20:36:09 +0100235 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000236 {
237 auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");
238
239 if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
240 ctx.Set<uint32_t>("clipIndex", 0);
241 return;
242 }
243 ++curAudioIdx;
244 ctx.Set<uint32_t>("clipIndex", curAudioIdx);
245 }
246
alexanderc350cdc2021-04-29 20:36:09 +0100247 static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000248 {
249 if (idx >= NUMBER_OF_FILES) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100250 printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
alexander3c798932021-03-26 21:42:19 +0000251 idx, NUMBER_OF_FILES);
252 return false;
253 }
254 ctx.Set<uint32_t>("clipIndex", idx);
255 return true;
256 }
257
alexanderc350cdc2021-04-29 20:36:09 +0100258 static bool PresentInferenceResult(hal_platform& platform, float result, float threshold)
alexander3c798932021-03-26 21:42:19 +0000259 {
260 constexpr uint32_t dataPsnTxtStartX1 = 20;
261 constexpr uint32_t dataPsnTxtStartY1 = 30;
262 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment */
263
264 platform.data_psn->set_text_color(COLOR_GREEN);
265
266 /* Display each result */
267 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
268
269 std::string resultStr = std::string{"Average anomaly score is: "} + std::to_string(result) +
270 std::string("\n") + std::string("Anomaly threshold is: ") + std::to_string(threshold) +
271 std::string("\n");
272
273 if (result > threshold) {
274 resultStr += std::string("Anomaly detected!");
275 } else {
276 resultStr += std::string("Everything fine, no anomaly detected!");
277 }
278
279 platform.data_psn->present_data_text(
280 resultStr.c_str(), resultStr.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100281 dataPsnTxtStartX1, rowIdx1, false);
alexander3c798932021-03-26 21:42:19 +0000282
283 info("%s\n", resultStr.c_str());
284
285 return true;
286 }
287
288 /**
289 * @brief Generic feature calculator factory.
290 *
291 * Returns lambda function to compute features using features cache.
292 * Real features math is done by a lambda function provided as a parameter.
293 * Features are written to input tensor memory.
294 *
295 * @tparam T feature vector type.
296 * @param inputTensor model input tensor pointer.
297 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
298 * @param compute features calculator function.
299 * @return lambda function to compute features.
300 */
301 template<class T>
302 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100303 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
304 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
alexander3c798932021-03-26 21:42:19 +0000305 {
306 /* Feature cache to be captured by lambda function*/
307 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
308
309 return [=](std::vector<int16_t>& audioDataWindow,
310 size_t index,
311 bool useCache,
312 size_t featuresOverlapIndex,
313 size_t resizeScale)
314 {
315 T *tensorData = tflite::GetTensorData<T>(inputTensor);
316 std::vector<T> features;
317
318 /* Reuse features from cache if cache is ready and sliding windows overlap.
319 * Overlap is in the beginning of sliding window with a size of a feature cache. */
320 if (useCache && index < featureCache.size()) {
321 features = std::move(featureCache[index]);
322 } else {
323 features = std::move(compute(audioDataWindow));
324 }
325 auto size = features.size() / resizeScale;
326 auto sizeBytes = sizeof(T);
327
328 /* Input should be transposed and "resized" by skipping elements. */
329 for (size_t outIndex = 0; outIndex < size; outIndex++) {
330 std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
331 }
332
333 /* Start renewing cache as soon iteration goes out of the windows overlap. */
334 if (index >= featuresOverlapIndex / resizeScale) {
335 featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
336 }
337 };
338 }
339
340 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100341 FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
342 size_t cacheSize,
343 std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
344
345 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
346 FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
alexander3c798932021-03-26 21:42:19 +0000347 size_t cacheSize,
alexanderc350cdc2021-04-29 20:36:09 +0100348 std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000349
350 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100351 FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
352 size_t cacheSize,
353 std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000354
355 template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100356 FeatureCalc<float>(TfLiteTensor *inputTensor,
357 size_t cacheSize,
358 std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000359
360
361 static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
362 GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean)
363 {
364 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
365
366 TfLiteQuantization quant = inputTensor->quantization;
367
368 if (kTfLiteAffineQuantization == quant.type) {
369
370 auto *quantParams = (TfLiteAffineQuantization *) quant.params;
371 const float quantScale = quantParams->scale->data[0];
372 const int quantOffset = quantParams->zero_point->data[0];
373
374 switch (inputTensor->type) {
375 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100376 melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor,
377 cacheSize,
378 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
379 return melSpec.MelSpecComputeQuant<int8_t>(
380 audioDataWindow,
381 quantScale,
382 quantOffset,
383 trainingMean);
384 }
alexander3c798932021-03-26 21:42:19 +0000385 );
386 break;
387 }
388 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100389 melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
390 cacheSize,
391 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
392 return melSpec.MelSpecComputeQuant<uint8_t>(
393 audioDataWindow,
394 quantScale,
395 quantOffset,
396 trainingMean);
397 }
alexander3c798932021-03-26 21:42:19 +0000398 );
399 break;
400 }
401 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100402 melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor,
403 cacheSize,
404 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
405 return melSpec.MelSpecComputeQuant<int16_t>(
406 audioDataWindow,
407 quantScale,
408 quantOffset,
409 trainingMean);
410 }
alexander3c798932021-03-26 21:42:19 +0000411 );
412 break;
413 }
414 default:
415 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
416 }
417
418
419 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100420 melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor,
421 cacheSize,
422 [=, &melSpec](
423 std::vector<int16_t>& audioDataWindow) {
424 return melSpec.ComputeMelSpec(
425 audioDataWindow,
426 trainingMean);
427 });
alexander3c798932021-03-26 21:42:19 +0000428 }
429 return melSpecFeatureCalc;
430 }
431
432} /* namespace app */
433} /* namespace arm */