blob: c71fdebc3f8d6caebfea79f210226a7a1c655703 [file] [log] [blame]
/*
 * SPDX-FileCopyrightText: Copyright 2021-2022 Arm Limited and/or its affiliates
 * <open-source-office@arm.com> SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "UseCaseHandler.hpp"

#include "AdMelSpectrogram.hpp"
#include "AdModel.hpp"
#include "AdProcessing.hpp"
#include "AudioUtils.hpp"
#include "Classifier.hpp"
#include "ImageUtils.hpp"
#include "InputFiles.hpp"
#include "UseCaseCommonUtils.hpp"
#include "hal.h"
#include "log_macros.h"

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <string>
alexander3c798932021-03-26 21:42:19 +000029
30namespace arm {
31namespace app {
32
/**
 * @brief           Presents the inference result on the LCD and in the log.
 * @param[in]       result      Anomaly score averaged over all inferences run on the clip.
 * @param[in]       threshold   If the score is larger than this value the clip is
 *                              reported as an anomaly.
 * @return          true if successful, false otherwise.
 **/
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010040 static bool PresentInferenceResult(float result, float threshold);
alexander3c798932021-03-26 21:42:19 +000041
/** @brief      Given a wav file name, return the AD model output index.
 *  @param[in]  wavFileName Audio WAV filename.
 *                          The file name should be in the format anything_goes_XX_here.wav,
 *                          where XX is the machine ID, e.g. 00, 02, 04 or 06.
 *  @return     AD model output index as an 8 bit integer, or -1 if the machine ID
 *              is not one of 00, 02, 04 or 06.
 **/
Richard Burton4e002792022-05-04 09:45:02 +010048 static int8_t OutputIndexFromFileName(std::string wavFileName);
alexander3c798932021-03-26 21:42:19 +000049
Richard Burton4e002792022-05-04 09:45:02 +010050 /* Anomaly Detection inference handler */
alexander3c798932021-03-26 21:42:19 +000051 bool ClassifyVibrationHandler(ApplicationContext& ctx, uint32_t clipIndex, bool runAll)
52 {
alexander3c798932021-03-26 21:42:19 +000053 constexpr uint32_t dataPsnTxtInfStartX = 20;
54 constexpr uint32_t dataPsnTxtInfStartY = 40;
55
alexander3c798932021-03-26 21:42:19 +000056 auto& model = ctx.Get<Model&>("model");
57
58 /* If the request has a valid size, set the audio index */
59 if (clipIndex < NUMBER_OF_FILES) {
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000060 if (!SetAppCtxIfmIdx(ctx, clipIndex, "clipIndex")) {
alexander3c798932021-03-26 21:42:19 +000061 return false;
62 }
63 }
64 if (!model.IsInited()) {
65 printf_err("Model is not initialised! Terminating processing.\n");
66 return false;
67 }
68
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000069 auto& profiler = ctx.Get<Profiler&>("profiler");
Richard Burton4e002792022-05-04 09:45:02 +010070 const auto melSpecFrameLength = ctx.Get<uint32_t>("frameLength");
71 const auto melSpecFrameStride = ctx.Get<uint32_t>("frameStride");
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000072 const auto scoreThreshold = ctx.Get<float>("scoreThreshold");
73 const auto trainingMean = ctx.Get<float>("trainingMean");
74 auto startClipIdx = ctx.Get<uint32_t>("clipIndex");
alexander3c798932021-03-26 21:42:19 +000075
76 TfLiteTensor* outputTensor = model.GetOutputTensor(0);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000077 TfLiteTensor* inputTensor = model.GetInputTensor(0);
alexander3c798932021-03-26 21:42:19 +000078
79 if (!inputTensor->dims) {
80 printf_err("Invalid input tensor dims\n");
81 return false;
82 }
83
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000084 AdPreProcess preProcess{inputTensor, melSpecFrameLength, melSpecFrameStride, trainingMean};
alexander3c798932021-03-26 21:42:19 +000085
Richard Burton4e002792022-05-04 09:45:02 +010086 AdPostProcess postProcess{outputTensor};
alexander3c798932021-03-26 21:42:19 +000087
88 do {
Kshitij Sisodia68fdd112022-04-06 13:03:20 +010089 hal_lcd_clear(COLOR_BLACK);
Richard Burton9b8d67a2021-12-10 12:32:51 +000090
alexander3c798932021-03-26 21:42:19 +000091 auto currentIndex = ctx.Get<uint32_t>("clipIndex");
92
93 /* Get the output index to look at based on id in the filename. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +000094 int8_t machineOutputIndex = OutputIndexFromFileName(GetFilename(currentIndex));
alexander3c798932021-03-26 21:42:19 +000095 if (machineOutputIndex == -1) {
96 return false;
97 }
98
alexander3c798932021-03-26 21:42:19 +000099 /* Creating a sliding window through the whole audio clip. */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000100 auto audioDataSlider =
101 audio::SlidingWindow<const int16_t>(GetAudioArray(currentIndex),
102 GetAudioArraySize(currentIndex),
103 preProcess.GetAudioWindowSize(),
104 preProcess.GetAudioDataStride());
alexander3c798932021-03-26 21:42:19 +0000105
106 /* Result is an averaged sum over inferences. */
107 float result = 0;
108
109 /* Display message on the LCD - inference running. */
110 std::string str_inf{"Running inference... "};
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100111 hal_lcd_display_text(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000112 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
Richard Burton4e002792022-05-04 09:45:02 +0100113
114 info("Running inference on audio clip %" PRIu32 " => %s\n",
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000115 currentIndex,
116 GetFilename(currentIndex));
alexander3c798932021-03-26 21:42:19 +0000117
118 /* Start sliding through audio clip. */
119 while (audioDataSlider.HasNext()) {
Richard Burton4e002792022-05-04 09:45:02 +0100120 const int16_t* inferenceWindow = audioDataSlider.Next();
alexander3c798932021-03-26 21:42:19 +0000121
Richard Burton4e002792022-05-04 09:45:02 +0100122 preProcess.SetAudioWindowIndex(audioDataSlider.Index());
123 preProcess.DoPreProcess(inferenceWindow, preProcess.GetAudioWindowSize());
alexander3c798932021-03-26 21:42:19 +0000124
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000125 info("Inference %zu/%zu\n",
126 audioDataSlider.Index() + 1,
alexander3c798932021-03-26 21:42:19 +0000127 audioDataSlider.TotalStrides() + 1);
128
129 /* Run inference over this audio clip sliding window */
alexander27b62d92021-05-04 20:46:08 +0100130 if (!RunInference(model, profiler)) {
131 return false;
132 }
alexander3c798932021-03-26 21:42:19 +0000133
Richard Burton4e002792022-05-04 09:45:02 +0100134 postProcess.DoPostProcess();
135 result += 0 - postProcess.GetOutputValue(machineOutputIndex);
alexander3c798932021-03-26 21:42:19 +0000136
137#if VERIFY_TEST_OUTPUT
Richard Burton4e002792022-05-04 09:45:02 +0100138 DumpTensor(outputTensor);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000139#endif /* VERIFY_TEST_OUTPUT */
alexander3c798932021-03-26 21:42:19 +0000140 } /* while (audioDataSlider.HasNext()) */
141
142 /* Use average over whole clip as final score. */
143 result /= (audioDataSlider.TotalStrides() + 1);
144
145 /* Erase. */
146 str_inf = std::string(str_inf.size(), ' ');
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100147 hal_lcd_display_text(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000148 str_inf.c_str(), str_inf.size(), dataPsnTxtInfStartX, dataPsnTxtInfStartY, 0);
alexander3c798932021-03-26 21:42:19 +0000149
150 ctx.Set<float>("result", result);
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100151 if (!PresentInferenceResult(result, scoreThreshold)) {
alexander3c798932021-03-26 21:42:19 +0000152 return false;
153 }
154
Isabella Gottardi8df12f32021-04-07 17:15:31 +0100155 profiler.PrintProfilingResult();
156
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000157 IncrementAppCtxIfmIdx(ctx, "clipIndex");
alexander3c798932021-03-26 21:42:19 +0000158
159 } while (runAll && ctx.Get<uint32_t>("clipIndex") != startClipIdx);
160
161 return true;
162 }
163
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100164 static bool PresentInferenceResult(float result, float threshold)
alexander3c798932021-03-26 21:42:19 +0000165 {
166 constexpr uint32_t dataPsnTxtStartX1 = 20;
167 constexpr uint32_t dataPsnTxtStartY1 = 30;
168 constexpr uint32_t dataPsnTxtYIncr = 16; /* Row index increment */
169
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100170 hal_lcd_set_text_color(COLOR_GREEN);
alexander3c798932021-03-26 21:42:19 +0000171
172 /* Display each result */
173 uint32_t rowIdx1 = dataPsnTxtStartY1 + 2 * dataPsnTxtYIncr;
174
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000175 std::string anomalyScore =
176 std::string{"Average anomaly score is: "} + std::to_string(result);
177 std::string anomalyThreshold =
178 std::string("Anomaly threshold is: ") + std::to_string(threshold);
alexander3c798932021-03-26 21:42:19 +0000179
George Gekov93e59512021-08-03 11:18:41 +0100180 std::string anomalyResult;
alexander3c798932021-03-26 21:42:19 +0000181 if (result > threshold) {
George Gekov93e59512021-08-03 11:18:41 +0100182 anomalyResult += std::string("Anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000183 } else {
George Gekov93e59512021-08-03 11:18:41 +0100184 anomalyResult += std::string("Everything fine, no anomaly detected!");
alexander3c798932021-03-26 21:42:19 +0000185 }
186
Kshitij Sisodia68fdd112022-04-06 13:03:20 +0100187 hal_lcd_display_text(
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000188 anomalyScore.c_str(), anomalyScore.size(), dataPsnTxtStartX1, rowIdx1, false);
alexander3c798932021-03-26 21:42:19 +0000189
George Gekov93e59512021-08-03 11:18:41 +0100190 info("%s\n", anomalyScore.c_str());
191 info("%s\n", anomalyThreshold.c_str());
192 info("%s\n", anomalyResult.c_str());
alexander3c798932021-03-26 21:42:19 +0000193
194 return true;
195 }
196
Richard Burton4e002792022-05-04 09:45:02 +0100197 static int8_t OutputIndexFromFileName(std::string wavFileName)
alexander3c798932021-03-26 21:42:19 +0000198 {
Richard Burton4e002792022-05-04 09:45:02 +0100199 /* Filename is assumed in the form machine_id_00.wav */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000200 std::string delimiter = "_"; /* First character used to split the file name up. */
Richard Burton4e002792022-05-04 09:45:02 +0100201 size_t delimiterStart;
202 std::string subString;
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000203 size_t machineIdxInString =
204 3; /* Which part of the file name the machine id should be at. */
alexander3c798932021-03-26 21:42:19 +0000205
Richard Burton4e002792022-05-04 09:45:02 +0100206 for (size_t i = 0; i < machineIdxInString; ++i) {
207 delimiterStart = wavFileName.find(delimiter);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000208 subString = wavFileName.substr(0, delimiterStart);
Richard Burton4e002792022-05-04 09:45:02 +0100209 wavFileName.erase(0, delimiterStart + delimiter.length());
alexander3c798932021-03-26 21:42:19 +0000210 }
Richard Burton4e002792022-05-04 09:45:02 +0100211
212 /* At this point substring should be 00.wav */
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000213 delimiter = "."; /* Second character used to split the file name up. */
Richard Burton4e002792022-05-04 09:45:02 +0100214 delimiterStart = subString.find(delimiter);
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000215 subString =
216 (delimiterStart != std::string::npos) ? subString.substr(0, delimiterStart) : subString;
Richard Burton4e002792022-05-04 09:45:02 +0100217
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000218 auto is_number = [](const std::string& str) -> bool {
Richard Burton4e002792022-05-04 09:45:02 +0100219 std::string::const_iterator it = str.begin();
Kshitij Sisodia2ea46232022-12-19 16:37:33 +0000220 while (it != str.end() && std::isdigit(*it))
221 ++it;
Richard Burton4e002792022-05-04 09:45:02 +0100222 return !str.empty() && it == str.end();
223 };
224
225 const int8_t machineIdx = is_number(subString) ? std::stoi(subString) : -1;
226
227 /* Return corresponding index in the output vector. */
228 if (machineIdx == 0) {
229 return 0;
230 } else if (machineIdx == 2) {
231 return 1;
232 } else if (machineIdx == 4) {
233 return 2;
234 } else if (machineIdx == 6) {
235 return 3;
236 } else {
237 printf_err("%d is an invalid machine index \n", machineIdx);
238 return -1;
239 }
alexander3c798932021-03-26 21:42:19 +0000240 }
241
242} /* namespace app */
243} /* namespace arm */