/*
 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef AD_PROCESSING_HPP
18#define AD_PROCESSING_HPP
19
20#include "BaseProcessing.hpp"
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010021#include "TensorFlowLiteMicro.hpp"
Richard Burton4e002792022-05-04 09:45:02 +010022#include "AudioUtils.hpp"
23#include "AdMelSpectrogram.hpp"
24#include "log_macros.h"
25
26namespace arm {
27namespace app {
28
29 /**
30 * @brief Pre-processing class for anomaly detection use case.
31 * Implements methods declared by BasePreProcess and anything else needed
32 * to populate input tensors ready for inference.
33 */
34 class AdPreProcess : public BasePreProcess {
35
36 public:
37 /**
38 * @brief Constructor for AdPreProcess class objects
39 * @param[in] inputTensor input tensor pointer from the tensor arena.
40 * @param[in] melSpectrogramFrameLen MEL spectrogram's frame length
41 * @param[in] melSpectrogramFrameStride MEL spectrogram's frame stride
42 * @param[in] adModelTrainingMean Training mean for the Anomaly detection model being used.
43 */
44 explicit AdPreProcess(TfLiteTensor* inputTensor,
45 uint32_t melSpectrogramFrameLen,
46 uint32_t melSpectrogramFrameStride,
47 float adModelTrainingMean);
48
49 ~AdPreProcess() = default;
50
51 /**
52 * @brief Function to invoke pre-processing and populate the input vector
53 * @param input pointer to input data. For anomaly detection, this is the pointer to
54 * the audio data.
55 * @param inputSize Size of the data being passed in for pre-processing.
56 * @return True if successful, false otherwise.
57 */
58 bool DoPreProcess(const void* input, size_t inputSize) override;
59
60 /**
61 * @brief Getter function for audio window size computed when constructing
62 * the class object.
63 * @return Audio window size as 32 bit unsigned integer.
64 */
65 uint32_t GetAudioWindowSize();
66
67 /**
68 * @brief Getter function for audio window stride computed when constructing
69 * the class object.
70 * @return Audio window stride as 32 bit unsigned integer.
71 */
72 uint32_t GetAudioDataStride();
73
74 /**
75 * @brief Setter function for current audio index. This is only used for evaluating
76 * if previously computed features can be re-used from cache.
77 */
78 void SetAudioWindowIndex(uint32_t idx);
79
80 private:
81 bool m_validInstance{false}; /**< Indicates the current object is valid. */
82 uint32_t m_melSpectrogramFrameLen{}; /**< MEL spectrogram's window frame length */
83 uint32_t m_melSpectrogramFrameStride{}; /**< MEL spectrogram's window frame stride */
84 uint8_t m_inputResizeScale{}; /**< Downscaling factor for the MEL energy matrix. */
85 uint32_t m_numMelSpecVectorsInAudioStride{}; /**< Number of frames to move across the audio. */
86 uint32_t m_audioDataWindowSize{}; /**< Audio window size computed based on other parameters. */
87 uint32_t m_audioDataStride{}; /**< Audio window stride computed. */
88 uint32_t m_numReusedFeatureVectors{}; /**< Number of MEL vectors that can be re-used */
89 uint32_t m_audioWindowIndex{}; /**< Current audio window index (from audio's sliding window) */
90
91 audio::SlidingWindow<const int16_t> m_melWindowSlider; /**< Internal MEL spectrogram window slider */
92 audio::AdMelSpectrogram m_melSpec; /**< MEL spectrogram computation object */
93 std::function<void
94 (std::vector<int16_t>&, int, bool, size_t, size_t)> m_featureCalc; /**< Feature calculator object */
95 };
96
97 class AdPostProcess : public BasePostProcess {
98 public:
99 /**
100 * @brief Constructor for AdPostProcess object.
101 * @param[in] outputTensor Output tensor pointer.
102 */
103 explicit AdPostProcess(TfLiteTensor* outputTensor);
104
105 ~AdPostProcess() = default;
106
107 /**
108 * @brief Function to do the post-processing on the output tensor.
109 * @return True if successful, false otherwise.
110 */
111 bool DoPostProcess() override;
112
113 /**
114 * @brief Getter function for an element from the de-quantised output vector.
115 * @param index Index of the element to be retrieved.
116 * @return index represented as a 32 bit floating point number.
117 */
118 float GetOutputValue(uint32_t index);
119
120 private:
121 TfLiteTensor* m_outputTensor{}; /**< Output tensor pointer */
122 std::vector<float> m_dequantizedOutputVec{}; /**< Internal output vector */
123
124 /**
125 * @brief De-quantizes and flattens the output tensor into a vector.
126 * @tparam T template parameter to indicate data type.
127 * @return True if successful, false otherwise.
128 */
129 template<typename T>
130 bool Dequantize()
131 {
132 TfLiteTensor* tensor = this->m_outputTensor;
133 if (tensor == nullptr) {
134 printf_err("Invalid output tensor.\n");
135 return false;
136 }
137 T* tensorData = tflite::GetTensorData<T>(tensor);
138
139 uint32_t totalOutputSize = 1;
140 for (int inputDim = 0; inputDim < tensor->dims->size; inputDim++){
141 totalOutputSize *= tensor->dims->data[inputDim];
142 }
143
144 /* For getting the floating point values, we need quantization parameters */
145 QuantParams quantParams = GetTensorQuantParams(tensor);
146
147 this->m_dequantizedOutputVec = std::vector<float>(totalOutputSize, 0);
148
149 for (size_t i = 0; i < totalOutputSize; ++i) {
150 this->m_dequantizedOutputVec[i] = quantParams.scale * (tensorData[i] - quantParams.offset);
151 }
152
153 return true;
154 }
155 };
156
157 /* Templated instances available: */
158 template bool AdPostProcess::Dequantize<int8_t>();
159
160 /**
161 * @brief Generic feature calculator factory.
162 *
163 * Returns lambda function to compute features using features cache.
164 * Real features math is done by a lambda function provided as a parameter.
165 * Features are written to input tensor memory.
166 *
167 * @tparam T feature vector type.
168 * @param inputTensor model input tensor pointer.
169 * @param cacheSize number of feature vectors to cache. Defined by the sliding window overlap.
170 * @param compute features calculator function.
171 * @return lambda function to compute features.
172 */
173 template<class T>
174 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)>
175 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
176 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
177 {
178 /* Feature cache to be captured by lambda function*/
179 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
180
181 return [=](std::vector<int16_t>& audioDataWindow,
182 size_t index,
183 bool useCache,
184 size_t featuresOverlapIndex,
185 size_t resizeScale)
186 {
187 T* tensorData = tflite::GetTensorData<T>(inputTensor);
188 std::vector<T> features;
189
190 /* Reuse features from cache if cache is ready and sliding windows overlap.
191 * Overlap is in the beginning of sliding window with a size of a feature cache. */
192 if (useCache && index < featureCache.size()) {
193 features = std::move(featureCache[index]);
194 } else {
195 features = std::move(compute(audioDataWindow));
196 }
197 auto size = features.size() / resizeScale;
198 auto sizeBytes = sizeof(T);
199
200 /* Input should be transposed and "resized" by skipping elements. */
201 for (size_t outIndex = 0; outIndex < size; outIndex++) {
202 std::memcpy(tensorData + (outIndex*size) + index, &features[outIndex*resizeScale], sizeBytes);
203 }
204
205 /* Start renewing cache as soon iteration goes out of the windows overlap. */
206 if (index >= featuresOverlapIndex / resizeScale) {
207 featureCache[index - featuresOverlapIndex / resizeScale] = std::move(features);
208 }
209 };
210 }
211
212 template std::function<void (std::vector<int16_t>&, size_t , bool, size_t, size_t)>
213 FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
214 size_t cacheSize,
215 std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);
216
217 template std::function<void(std::vector<int16_t>&, size_t, bool, size_t, size_t)>
218 FeatureCalc<float>(TfLiteTensor *inputTensor,
219 size_t cacheSize,
220 std::function<std::vector<float>(std::vector<int16_t>&)> compute);
221
222 std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
223 GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
224 TfLiteTensor* inputTensor,
225 size_t cacheSize,
226 float trainingMean);
227
228} /* namespace app */
229} /* namespace arm */
230
231#endif /* AD_PROCESSING_HPP */