blob: 0179d6bc2d4d8364bbb279bf5d0c0cfdad296654 [file] [log] [blame]
alexander3c798932021-03-26 21:42:19 +00001/*
Richard Burtoned35a6f2022-02-14 11:55:35 +00002 * Copyright (c) 2021-2022 Arm Limited. All rights reserved.
alexander3c798932021-03-26 21:42:19 +00003 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#include "UseCaseHandler.hpp"
18
19#include "AdModel.hpp"
20#include "InputFiles.hpp"
21#include "Classifier.hpp"
22#include "hal.h"
23#include "AdMelSpectrogram.hpp"
24#include "AudioUtils.hpp"
Richard Burtoned35a6f2022-02-14 11:55:35 +000025#include "ImageUtils.hpp"
alexander3c798932021-03-26 21:42:19 +000026#include "UseCaseCommonUtils.hpp"
alexander31ae9f02022-02-10 16:15:54 +000027#include "log_macros.h"
Richard Burton4e002792022-05-04 09:45:02 +010028#include "AdProcessing.hpp"
alexander3c798932021-03-26 21:42:19 +000029
30namespace arm {
31namespace app {
32
/**
 * @brief       Presents the anomaly detection result for one clip: shows the
 *              average anomaly score on the LCD and logs the score, the
 *              threshold and the resulting verdict.
 * @param[in]   result      Average anomaly score over all inference windows of the clip.
 * @param[in]   threshold   If result is larger than this value the clip is
 *                          reported as an anomaly.
 * @return      true if successful, false otherwise.
 **/
static bool PresentInferenceResult(float result, float threshold);
alexander3c798932021-03-26 21:42:19 +000041
/**
 * @brief       Given a WAV file name, return the AD model output index that
 *              corresponds to the machine ID embedded in the name.
 * @param[in]   wavFileName   Audio WAV file name. Expected in the format
 *                            anything_goes_XX_here.wav (or anything_goes_XX.wav),
 *                            where XX is the machine ID, e.g. 00, 02, 04 or 06.
 * @return      AD model output index as an 8-bit integer: 0, 1, 2 or 3 for
 *              machine IDs 0, 2, 4 or 6; -1 if the ID is missing, not a number
 *              or not one of those values.
 **/
static int8_t OutputIndexFromFileName(std::string wavFileName);
alexander3c798932021-03-26 21:42:19 +000049
Richard Burton4e002792022-05-04 09:45:02 +010050 /* Anomaly Detection inference handler */
alexander3c798932021-03-26 21:42:19 +000051 bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
52 {
alexander3c798932021-03-26 21:42:19 +000053 constexpr uint32_t dataPsnTxtInfStartX = 20;
54 constexpr uint32_t dataPsnTxtInfStartY = 40;
55
alexander3c798932021-03-26 21:42:19 +000056 auto& model = ctx.Get<Model&>("model");
57
58 /* If the request has a valid size, set the audio index */
59 if (clipIndex < NUMBER_OF_FILES) {
Éanna Ó Catháin8f958872021-09-15 09:32:30 +010060 if (!SetAppCtxIfmIdx(ctx, clipIndex,"clipIndex")) {
alexander3c798932021-03-26 21:42:19 +000061 return false;
62 }
63 }
64 if (!model.IsInited()) {
65 printf_err("Model is not initialised! Terminating processing.\n");
66 return false;
67 }
68
Richard Burton4e002792022-05-04 09:45:02 +010069 auto& profiler = ctx.Get<Profiler&>("profiler");
70 const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
71 const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
alexander3c798932021-03-26 21:42:19 +000072 const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
Isabella Gottardi8df12f32021-04-07 17:15:31 +010073 const auto trainingMean = ctx.Get<float>("trainingMean");
alexander3c798932021-03-26 21:42:19 +000074 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
75
76 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
77 TfLiteTensor* inputTensor = model.GetInputTensor(0);
78
79 if (!inputTensor->dims) {
80 printf_err("Invalid input tensor dims\n");
81 return false;
82 }
83
Richard Burton4e002792022-05-04 09:45:02 +010084 AdPreProcess preProcess{
85 inputTensor,
86 melSpecFrameLength,
87 melSpecFrameStride,
88 trainingMean};
alexander3c798932021-03-26 21:42:19 +000089
Richard Burton4e002792022-05-04 09:45:02 +010090 AdPostProcess postProcess{outputTensor};
alexander3c798932021-03-26 21:42:19 +000091
92 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010093 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +000094
alexander3c798932021-03-26 21:42:19 +000095 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
96
97 /* Get the output index to look at based on id in the filename. */
98 int8_t machineOutputIndex = OutputIndexFromFileName(get_filename(currentIndex));
99 if (machineOutputIndex == -1) {
100 return false;
101 }
102
alexander3c798932021-03-26 21:42:19 +0000103 /* Creating a sliding window through the whole audio clip. */
104 auto audioDataSlider = audio::SlidingWindow<const int16_t>(
Richard Burton4e002792022-05-04 09:45:02 +0100105 get_audio_array(currentIndex),
106 get_audio_array_size(currentIndex),
107 preProcess.GetAudioWindowSize(),
108 preProcess.GetAudioDataStride());
alexander3c798932021-03-26 21:42:19 +0000109
110 /* Result is an averaged sum over inferences. */
111 float result = 0;
112
113 /* Display message on the LCD - inference running. */
114 std::string str_inf{"Running inference... "};
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100115 hal_lcd_display_text(
Richard Burton4e002792022-05-04 09:45:02 +0100116 str_inf.c_str(), str_inf.size(),
117 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
118
119 info("Running inference on audio clip %" PRIu32 " => %s\n",
120 currentIndex, get_filename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000121
122 /* Start sliding through audio clip. */
123 while (audioDataSlider.HasNext()) {
Richard Burton4e002792022-05-04 09:45:02 +0100124 const int16_t* inferenceWindow = audioDataSlider.Next();
alexander3c798932021-03-26 21:42:19 +0000125
Richard Burton4e002792022-05-04 09:45:02 +0100126 preProcess.SetAudioWindowIndex(audioDataSlider.Index());
127 preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
alexander3c798932021-03-26 21:42:19 +0000128
129 info("Inference %zu/%zu\n", audioDataSlider.Index() + 1,
130 audioDataSlider.TotalStrides() + 1);
131
132 /* Run inference over this audio clip sliding window */
alexander27b62d92021-05-04 20:46:08 +0100133 if (!RunInference(model, profiler)) {
134 return false;
135 }
alexander3c798932021-03-26 21:42:19 +0000136
Richard Burton4e002792022-05-04 09:45:02 +0100137 postProcess.DoPostProcess();
138 result += 0 - postProcess.GetOutputValue(machineOutputIndex);
alexander3c798932021-03-26 21:42:19 +0000139
140#if VERIFY_TEST_OUTPUT
Richard Burton4e002792022-05-04 09:45:02 +0100141 DumpTensor(outputTensor);
alexander3c798932021-03-26 21:42:19 +0000142#endif /* VERIFY_TEST_OUTPUT */
143 } /* while (audioDataSlider.HasNext()) */
144
145 /* Use average over whole clip as final score. */
146 result /= (audioDataSlider.TotalStrides() + 1);
147
148 /* Erase. */
149 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100150 hal_lcd_display_text(
alexander3c798932021-03-26 21:42:19 +0000151 str_inf.c_str(), str_inf.size(),
152 dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
153
154 ctx.Set<float>("result", result);
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100155 if (!PresentInferenceResult(result, scoreThreshold)) {
alexander3c798932021-03-26 21:42:19 +0000156 return false;
157 }
158
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100159 profiler.PrintProfilingResult();
160
Éanna Ó Catháin8f958872021-09-15 09:32:30 +0100161 IncrementAppCtxIfmIdx(ctx,"clipIndex");
alexander3c798932021-03-26 21:42:19 +0000162
163 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
164
165 return true;
166 }
167
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100168 static bool PresentInferenceResult(float result, float threshold)
alexander3c798932021-03-26 21:42:19 +0000169 {
170 constexpr uint32_t dataPsnTxtStartX1 = 20;
171 constexpr uint32_t dataPsnTxtStartY1 = 30;
172 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment */
173
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100174 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000175
176 /* Display each result */
177 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
178
George Gekov93e59512021-08-03 11:18:41 +0100179 std::string anomalyScore = std::string{"Average anomaly score is: "} + std::to_string(result);
180 std::string anomalyThreshold = std::string("Anomaly threshold is: ") + std::to_string(threshold);
alexander3c798932021-03-26 21:42:19 +0000181
George Gekov93e59512021-08-03 11:18:41 +0100182 std::string anomalyResult;
alexander3c798932021-03-26 21:42:19 +0000183 if (result > threshold) {
George Gekov93e59512021-08-03 11:18:41 +0100184 anomalyResult += std::string("Anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000185 } else {
George Gekov93e59512021-08-03 11:18:41 +0100186 anomalyResult += std::string("Everything fine, no anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000187 }
188
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100189 hal_lcd_display_text(
George Gekov93e59512021-08-03 11:18:41 +0100190 anomalyScore.c_str(), anomalyScore.size(),
alexanderc350cdc2021-04-29 20:36:09 +0100191 dataPsnTxtStartX1, rowIdx1, false);
alexander3c798932021-03-26 21:42:19 +0000192
George Gekov93e59512021-08-03 11:18:41 +0100193 info("%s\n", anomalyScore.c_str());
194 info("%s\n", anomalyThreshold.c_str());
195 info("%s\n", anomalyResult.c_str());
alexander3c798932021-03-26 21:42:19 +0000196
197 return true;
198 }
199
Richard Burton4e002792022-05-04 09:45:02 +0100200 static int8_t OutputIndexFromFileName(std::string wavFileName)
alexander3c798932021-03-26 21:42:19 +0000201 {
Richard Burton4e002792022-05-04 09:45:02 +0100202 /* Filename is assumed in the form machine_id_00.wav */
203 std::string delimiter = "_"; /* First character used to split the file name up. */
204 size_t delimiterStart;
205 std::string subString;
206 size_t machineIdxInString = 3; /* Which part of the file name the machine id should be at. */
alexander3c798932021-03-26 21:42:19 +0000207
Richard Burton4e002792022-05-04 09:45:02 +0100208 for (size_t i = 0; i < machineIdxInString; ++i) {
209 delimiterStart = wavFileName.find(delimiter);
210 subString = wavFileName.substr(0, delimiterStart);
211 wavFileName.erase(0, delimiterStart + delimiter.length());
alexander3c798932021-03-26 21:42:19 +0000212 }
Richard Burton4e002792022-05-04 09:45:02 +0100213
214 /* At this point substring should be 00.wav */
215 delimiter = "."; /* Second character used to split the file name up. */
216 delimiterStart = subString.find(delimiter);
217 subString = (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
218
219 auto is_number = [](const std::string& str) -> bool
220 {
221 std::string::const_iterator it = str.begin();
222 while (it != str.end() && std::isdigit(*it)) ++it;
223 return !str.empty() && it == str.end();
224 };
225
226 const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
227
228 /* Return corresponding index in the output vector. */
229 if (machineIdx == 0) {
230 return 0;
231 } else if (machineIdx == 2) {
232 return 1;
233 } else if (machineIdx == 4) {
234 return 2;
235 } else if (machineIdx == 6) {
236 return 3;
237 } else {
238 printf_err("%d is an invalid machine index \n", machineIdx);
239 return -1;
240 }
alexander3c798932021-03-26 21:42:19 +0000241 }
242
243} /* namespace app */
244} /* namespace arm */