blob: 0c781793e92780f42a6f8050bb3db820e9392b86 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
2 * Copyright (c) 2021 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "AdModel.hpp"
20#include "InputFiles.hpp"
21#include "Classifier.hpp"
22#include "hal.h"
23#include "AdMelSpectrogram.hpp"
24#include "AudioUtils.hpp"
25#include "UseCaseCommonUtils.hpp"
26#include "AdPostProcessing.hpp"
27
28namespace arm {
29namespace app {
30
    /**
     * @brief           Helper function to increment current audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     **/
    static void IncrementAppCtxClipIdx(ApplicationContext& ctx);

    /**
     * @brief           Helper function to set the audio clip index.
     * @param[in,out]   ctx   Reference to the application context object.
     * @param[in]       idx   Value to be set.
     * @return          true if index is set, false otherwise.
     **/
    static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx);

    /**
     * @brief           Presents inference results using the data presentation
     *                  object.
     * @param[in]       platform    Reference to the hal platform object.
     * @param[in]       result      Average anomaly score over the processed clip.
     * @param[in]       threshold   An anomaly is reported when result is larger than this value.
     * @return          true if successful, false otherwise.
     **/
    static bool PresentInferenceResult(hal_platform& platform, float result, float threshold);

    /**
     * @brief           Returns a function to perform feature calculation and populate input
     *                  tensor data with Mel spectrogram data.
     *
     * Input tensor data type check is performed to choose the correct feature data type.
     * If the tensor has an integer data type then the original features are quantised.
     *
     * Warning: the Mel spectrogram calculator provided as input is captured by reference,
     * so it must outlive the returned function.
     *
     * @param[in]       melSpec       Mel spectrogram feature calculator.
     * @param[in,out]   inputTensor   Input tensor pointer to store calculated features.
     * @param[in]       cacheSize     Size of the feature vectors cache (number of feature vectors).
     * @param[in]       trainingMean  Training mean.
     * @return          Function to be called with an audio window, the sliding window index,
     *                  a use-cache flag, the features overlap index and the resize scale.
     *                  Empty if the input tensor type is not supported.
     */
    static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
    GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
                         TfLiteTensor* inputTensor,
                         size_t cacheSize,
                         float trainingMean);
75
76 /* Vibration classification handler */
77 bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
78 {
79 auto& platform = ctx.Get<hal_platform&>("platform");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010080 auto& profiler = ctx.Get<Profiler&>("profiler");
alexander3c798932021-03-26 21:42:19 +000081
82 constexpr uint32_t dataPsnTxtInfStartX = 20;
83 constexpr uint32_t dataPsnTxtInfStartY = 40;
84
85 platform.data_psn->clear(COLOR_BLACK);
86
87 auto& model = ctx.Get<Model&>("model");
88
89 /* If the request has a valid size, set the audio index */
90 if (clipIndex < NUMBER_OF_FILES) {
alexanderc350cdc2021-04-29 20:36:09 +010091 if (!SetAppCtxClipIdx(ctx, clipIndex)) {
alexander3c798932021-03-26 21:42:19 +000092 return false;
93 }
94 }
95 if (!model.IsInited()) {
96 printf_err("Model is not initialised! Terminating processing.\n");
97 return false;
98 }
99
100 const auto frameLength = ctx.Get<int>("frameLength");
101 const auto frameStride = ctx.Get<int>("frameStride");
102 const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100103 const auto trainingMean = ctx.Get<float>("trainingMean");
alexander3c798932021-03-26 21:42:19 +0000104 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
105
106 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
107 TfLiteTensor* inputTensor = model.GetInputTensor(0);
108
109 if (!inputTensor->dims) {
110 printf_err("Invalid input tensor dims\n");
111 return false;
112 }
113
114 TfLiteIntArray* inputShape = model.GetInputShape(0);
115 const uint32_t kNumRows = inputShape->data[1];
116 const uint32_t kNumCols = inputShape->data[2];
117
118 audio::AdMelSpectrogram melSpec = audio::AdMelSpectrogram(frameLength);
119 melSpec.Init();
120
121 /* Deduce the data length required for 1 inference from the network parameters. */
122 const uint8_t inputResizeScale = 2;
123 const uint32_t audioDataWindowSize = (((inputResizeScale * kNumCols) - 1) * frameStride) + frameLength;
124
125 /* We are choosing to move by 20 frames across the audio for each inference. */
126 const uint8_t nMelSpecVectorsInAudioStride = 20;
127
128 auto audioDataStride = nMelSpecVectorsInAudioStride * frameStride;
129
130 do {
131 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
132
133 /* Get the output index to look at based on id in the filename. */
134 int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex));
135 if (machineOutputIndex == -1) {
136 return false;
137 }
138
139 /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
140 * "resizing" done here by multiplying stride by resize scale. */
141 auto audioMelSpecWindowSlider = audio::SlidingWindow<const int16_t>(
142 get_audio_array(currentIndex),
143 audioDataWindowSize, frameLength,
144 frameStride * inputResizeScale);
145
146 /* Creating a sliding window through the whole audio clip. */
147 auto audioDataSlider = audio::SlidingWindow<const int16_t>(
148 get_audio_array(currentIndex),
149 get_audio_array_size(currentIndex),
150 audioDataWindowSize, audioDataStride);
151
152 /* Calculate number of the feature vectors in the window overlap region taking into account resizing.
153 * These feature vectors will be reused.*/
154 auto numberOfReusedFeatureVectors = kNumRows - (nMelSpecVectorsInAudioStride / inputResizeScale);
155
156 /* Construct feature calculation function. */
157 auto melSpecFeatureCalc = GetFeatureCalculator(melSpec, inputTensor,
158 numberOfReusedFeatureVectors, trainingMean);
159 if (!melSpecFeatureCalc){
160 return false;
161 }
162
163 /* Result is an averaged sum over inferences. */
164 float result = 0;
165
166 /* Display message on the LCD - inference running. */
167 std::string str_inf{"Running inference... "};
168 platform.data_psn->present_data_text(
169 str_inf.c_str(), str_inf.size(),
170 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100171 info("Running inference on audio clip %" PRIu32 " => %s\n", currentIndex, get_filename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000172
173 /* Start sliding through audio clip. */
174 while (audioDataSlider.HasNext()) {
175 const int16_t *inferenceWindow = audioDataSlider.Next();
176
177 /* We moved to the next window - set the features sliding to the new address. */
178 audioMelSpecWindowSlider.Reset(inferenceWindow);
179
180 /* The first window does not have cache ready. */
181 bool useCache = audioDataSlider.Index() > 0 && numberOfReusedFeatureVectors > 0;
182
183 /* Start calculating features inside one audio sliding window. */
184 while (audioMelSpecWindowSlider.HasNext()) {
185 const int16_t *melSpecWindow = audioMelSpecWindowSlider.Next();
186 std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(melSpecWindow,
187 melSpecWindow + frameLength);
188
189 /* Compute features for this window and write them to input tensor. */
190 melSpecFeatureCalc(melSpecAudioData, audioMelSpecWindowSlider.Index(),
191 useCache, nMelSpecVectorsInAudioStride, inputResizeScale);
192 }
193
194 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
195 audioDataSlider.TotalStrides() + 1);
196
197 /* Run inference over this audio clip sliding window */
alexander27b62d92021-05-04 20:46:08 +0100198 if (!RunInference(model, profiler)) {
199 return false;
200 }
alexander3c798932021-03-26 21:42:19 +0000201
202 /* Use the negative softmax score of the corresponding index as the outlier score */
203 std::vector<float> dequantOutput = Dequantize<int8_t>(outputTensor);
204 Softmax(dequantOutput);
205 result += -dequantOutput[machineOutputIndex];
206
207#if VERIFY_TEST_OUTPUT
208 arm::app::DumpTensor(outputTensor);
209#endif /* VERIFY_TEST_OUTPUT */
210 } /* while (audioDataSlider.HasNext()) */
211
212 /* Use average over whole clip as final score. */
213 result /= (audioDataSlider.TotalStrides() + 1);
214
215 /* Erase. */
216 str_inf = std::string(str_inf.size(), ' ');
217 platform.data_psn->present_data_text(
218 str_inf.c_str(), str_inf.size(),
219 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
220
221 ctx.Set<float>("result", result);
alexanderc350cdc2021-04-29 20:36:09 +0100222 if (!PresentInferenceResult(platform, result, scoreThreshold)) {
alexander3c798932021-03-26 21:42:19 +0000223 return false;
224 }
225
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100226 profiler.PrintProfilingResult();
227
alexanderc350cdc2021-04-29 20:36:09 +0100228 IncrementAppCtxClipIdx(ctx);
alexander3c798932021-03-26 21:42:19 +0000229
230 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
231
232 return true;
233 }
234
alexanderc350cdc2021-04-29 20:36:09 +0100235 static void IncrementAppCtxClipIdx(ApplicationContext& ctx)
alexander3c798932021-03-26 21:42:19 +0000236 {
237 auto curAudioIdx = ctx.Get<uint32_t>("clipIndex");
238
239 if (curAudioIdx + 1 >= NUMBER_OF_FILES) {
240 ctx.Set<uint32_t>("clipIndex", 0);
241 return;
242 }
243 ++curAudioIdx;
244 ctx.Set<uint32_t>("clipIndex", curAudioIdx);
245 }
246
alexanderc350cdc2021-04-29 20:36:09 +0100247 static bool SetAppCtxClipIdx(ApplicationContext& ctx, uint32_t idx)
alexander3c798932021-03-26 21:42:19 +0000248 {
249 if (idx >= NUMBER_OF_FILES) {
Kshitij Sisodiaf9c19ea2021-05-07 16:08:14 +0100250 printf_err("Invalid idx %" PRIu32 " (expected less than %u)\n",
alexander3c798932021-03-26 21:42:19 +0000251 idx, NUMBER_OF_FILES);
252 return false;
253 }
254 ctx.Set<uint32_t>("clipIndex", idx);
255 return true;
256 }
257
alexanderc350cdc2021-04-29 20:36:09 +0100258 static bool PresentInferenceResult(hal_platform& platform, float result, float threshold)
alexander3c798932021-03-26 21:42:19 +0000259 {
260 constexpr uint32_t dataPsnTxtStartX1 = 20;
261 constexpr uint32_t dataPsnTxtStartY1 = 30;
262 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment */
263
264 platform.data_psn->set_text_color(COLOR_GREEN);
265
266 /* Display each result */
267 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
268
George Gekov93e59512021-08-03 11:18:41 +0100269 std::string anomalyScore = std::string{"Average anomaly score is: "} + std::to_string(result);
270 std::string anomalyThreshold = std::string("Anomaly threshold is: ") + std::to_string(threshold);
alexander3c798932021-03-26 21:42:19 +0000271
George Gekov93e59512021-08-03 11:18:41 +0100272 std::string anomalyResult;
alexander3c798932021-03-26 21:42:19 +0000273 if (result > threshold) {
George Gekov93e59512021-08-03 11:18:41 +0100274 anomalyResult += std::string("Anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000275 } else {
George Gekov93e59512021-08-03 11:18:41 +0100276 anomalyResult += std::string("Everything fine, no anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000277 }
278
279 platform.data_psn->present_data_text(
George Gekov93e59512021-08-03 11:18:41 +0100280 anomalyScore.c_str(), anomalyScore.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100281 dataPsnTxtStartX1, rowIdx1, false);
alexander3c798932021-03-26 21:42:19 +0000282
George Gekov93e59512021-08-03 11:18:41 +0100283 info("%s\n", anomalyScore.c_str());
284 info("%s\n", anomalyThreshold.c_str());
285 info("%s\n", anomalyResult.c_str());
alexander3c798932021-03-26 21:42:19 +0000286
287 return true;
288 }
289
290 /**
291 * @brief Generic feature calculator factory.
292 *
293 * Returns lambda function to compute features using features cache.
294 * Real features math is done by a lambda function provided as a parameter.
295 * Features are written to input tensor memory.
296 *
297 * @tparam T feature vector type.
298 * @param inputTensor model input tensor pointer.
299 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
300 * @param compute features calculator function.
301 * @return lambda function to compute features.
302 */
303 template<class T>
304 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100305 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
306 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
alexander3c798932021-03-26 21:42:19 +0000307 {
308 /* Feature cache to be captured by lambda function*/
309 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
310
311 return [=](std::vector<int16_t>& audioDataWindow,
312 size_t index,
313 bool useCache,
314 size_t featuresOverlapIndex,
315 size_t resizeScale)
316 {
317 T *tensorData = tflite::GetTensorData<T>(inputTensor);
318 std::vector<T> features;
319
320 /* Reuse features from cache if cache is ready and sliding windows overlap.
321 * Overlap is in the beginning of sliding window with a size of a feature cache. */
322 if (useCache && index < featureCache.size()) {
323 features = std::move(featureCache[index]);
324 } else {
325 features = std::move(compute(audioDataWindow));
326 }
327 auto size = features.size() / resizeScale;
328 auto sizeBytes = sizeof(T);
329
330 /* Input should be transposed and "resized" by skipping elements. */
331 for (size_t outIndex = 0; outIndex < size; outIndex++) {
332 std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
333 }
334
335 /* Start renewing cache as soon iteration goes out of the windows overlap. */
336 if (index >= featuresOverlapIndex / resizeScale) {
337 featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
338 }
339 };
340 }
341
342 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100343 FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
344 size_t cacheSize,
345 std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
346
347 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
348 FeatureCalc<uint8_t>(TfLiteTensor* inputTensor,
alexander3c798932021-03-26 21:42:19 +0000349 size_t cacheSize,
alexanderc350cdc2021-04-29 20:36:09 +0100350 std::function<std::vector<uint8_t> (std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000351
352 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100353 FeatureCalc<int16_t>(TfLiteTensor* inputTensor,
354 size_t cacheSize,
355 std::function<std::vector<int16_t> (std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000356
357 template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
alexanderc350cdc2021-04-29 20:36:09 +0100358 FeatureCalc<float>(TfLiteTensor *inputTensor,
359 size_t cacheSize,
360 std::function<std::vector<float>(std::vector<int16_t>&)> compute);
alexander3c798932021-03-26 21:42:19 +0000361
362
363 static std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
364 GetFeatureCalculator(audio::AdMelSpectrogram& melSpec, TfLiteTensor* inputTensor, size_t cacheSize, float trainingMean)
365 {
366 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
367
368 TfLiteQuantization quant = inputTensor->quantization;
369
370 if (kTfLiteAffineQuantization == quant.type) {
371
372 auto *quantParams = (TfLiteAffineQuantization *) quant.params;
373 const float quantScale = quantParams->scale->data[0];
374 const int quantOffset = quantParams->zero_point->data[0];
375
376 switch (inputTensor->type) {
377 case kTfLiteInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100378 melSpecFeatureCalc = FeatureCalc<int8_t>(inputTensor,
379 cacheSize,
380 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
381 return melSpec.MelSpecComputeQuant<int8_t>(
382 audioDataWindow,
383 quantScale,
384 quantOffset,
385 trainingMean);
386 }
alexander3c798932021-03-26 21:42:19 +0000387 );
388 break;
389 }
390 case kTfLiteUInt8: {
alexanderc350cdc2021-04-29 20:36:09 +0100391 melSpecFeatureCalc = FeatureCalc<uint8_t>(inputTensor,
392 cacheSize,
393 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
394 return melSpec.MelSpecComputeQuant<uint8_t>(
395 audioDataWindow,
396 quantScale,
397 quantOffset,
398 trainingMean);
399 }
alexander3c798932021-03-26 21:42:19 +0000400 );
401 break;
402 }
403 case kTfLiteInt16: {
alexanderc350cdc2021-04-29 20:36:09 +0100404 melSpecFeatureCalc = FeatureCalc<int16_t>(inputTensor,
405 cacheSize,
406 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
407 return melSpec.MelSpecComputeQuant<int16_t>(
408 audioDataWindow,
409 quantScale,
410 quantOffset,
411 trainingMean);
412 }
alexander3c798932021-03-26 21:42:19 +0000413 );
414 break;
415 }
416 default:
417 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
418 }
419
420
421 } else {
alexanderc350cdc2021-04-29 20:36:09 +0100422 melSpecFeatureCalc = melSpecFeatureCalc = FeatureCalc<float>(inputTensor,
423 cacheSize,
424 [=, &melSpec](
425 std::vector<int16_t>& audioDataWindow) {
426 return melSpec.ComputeMelSpec(
427 audioDataWindow,
428 trainingMean);
429 });
alexander3c798932021-03-26 21:42:19 +0000430 }
431 return melSpecFeatureCalc;
432 }
433
434} /* namespace app */
435} /* namespace arm */