/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
#include "KwsProcessing.hpp"

#include "ImageUtils.hpp"
#include "MicroNetKwsModel.hpp"
#include "log_macros.h"

#include <cstring>
#include <memory>
21
22namespace arm {
23namespace app {
24
    /* Sets up the MFCC engine and, from the model input shape and the MFCC frame
     * parameters, derives how much audio a single inference consumes and how the
     * audio/MFCC sliding windows advance between inferences. */
    KWSPreProcess::KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride):
        m_mfccFrameLength{mfccFrameLength},
        m_mfccFrameStride{mfccFrameStride},
        m_mfcc{audio::MicroNetKwsMFCC(numFeatures, mfccFrameLength)}
    {
        /* NOTE(review): an uninitialised model is only logged here, yet the code
         * below still calls model->GetInputShape() - confirm callers guarantee
         * the model is initialised before constructing this object. */
        if (!model->IsInited()) {
            printf_err("Model is not initialised!.\n");
        }
        this->m_model = model;
        this->m_mfcc.Init();

        /* Number of MFCC feature vectors (tensor rows) the network expects per inference. */
        TfLiteIntArray* inputShape = model->GetInputShape(0);
        const uint32_t numMfccFrames = inputShape->data[arm::app::MicroNetKwsModel::ms_inputRowsIdx];

        /* Deduce the data length required for 1 inference from the network parameters. */
        this->m_audioDataWindowSize = numMfccFrames * this->m_mfccFrameStride +
                (this->m_mfccFrameLength - this->m_mfccFrameStride);

        /* Creating an MFCC feature sliding window for the data required for 1 inference.
         * The data pointer is set per-clip later, in DoPreProcess(). */
        this->m_mfccSlidingWindow = audio::SlidingWindow<const int16_t>(nullptr, this->m_audioDataWindowSize,
                this->m_mfccFrameLength, this->m_mfccFrameStride);

        /* For longer audio clips we choose to move by half the audio window size
         * => for a 1 second window size there is an overlap of 0.5 seconds. */
        this->m_audioDataStride = this->m_audioDataWindowSize / 2;

        /* To have the previously calculated features re-usable, stride must be multiple
         * of MFCC features window stride. Reduce stride through audio if needed. */
        if (0 != this->m_audioDataStride % this->m_mfccFrameStride) {
            this->m_audioDataStride -= this->m_audioDataStride % this->m_mfccFrameStride;
        }

        /* Number of NEW MFCC vectors each audio stride contributes (the rest overlap). */
        this->m_numMfccVectorsInAudioStride = this->m_audioDataStride / this->m_mfccFrameStride;

        /* Calculate number of the feature vectors in the window overlap region.
         * These feature vectors will be reused.*/
        this->m_numReusedMfccVectors = this->m_mfccSlidingWindow.TotalStrides() + 1
                - this->m_numMfccVectorsInAudioStride;

        /* Construct feature calculation function (type-dispatched on the input tensor). */
        this->m_mfccFeatureCalculator = GetFeatureCalculator(this->m_mfcc, this->m_model->GetInputTensor(0),
                                                             this->m_numReusedMfccVectors);

        if (!this->m_mfccFeatureCalculator) {
            printf_err("Feature calculator not initialized.");
        }
    }
72
73 bool KWSPreProcess::DoPreProcess(const void* data, size_t inputSize)
74 {
75 UNUSED(inputSize);
76 if (data == nullptr) {
77 printf_err("Data pointer is null");
78 }
79
80 /* Set the features sliding window to the new address. */
81 auto input = static_cast<const int16_t*>(data);
82 this->m_mfccSlidingWindow.Reset(input);
83
84 /* Cache is only usable if we have more than 1 inference in an audio clip. */
85 bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedMfccVectors > 0;
86
87 /* Use a sliding window to calculate MFCC features frame by frame. */
88 while (this->m_mfccSlidingWindow.HasNext()) {
89 const int16_t* mfccWindow = this->m_mfccSlidingWindow.Next();
90
91 std::vector<int16_t> mfccFrameAudioData = std::vector<int16_t>(mfccWindow,
92 mfccWindow + this->m_mfccFrameLength);
93
94 /* Compute features for this window and write them to input tensor. */
95 this->m_mfccFeatureCalculator(mfccFrameAudioData, this->m_mfccSlidingWindow.Index(),
96 useCache, this->m_numMfccVectorsInAudioStride);
97 }
98
99 debug("Input tensor populated \n");
100
101 return true;
102 }
103
104 /**
105 * @brief Generic feature calculator factory.
106 *
107 * Returns lambda function to compute features using features cache.
108 * Real features math is done by a lambda function provided as a parameter.
109 * Features are written to input tensor memory.
110 *
111 * @tparam T Feature vector type.
112 * @param[in] inputTensor Model input tensor pointer.
113 * @param[in] cacheSize Number of feature vectors to cache. Defined by the sliding window overlap.
114 * @param[in] compute Features calculator function.
115 * @return Lambda function to compute features.
116 */
117 template<class T>
118 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
119 KWSPreProcess::FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
120 std::function<std::vector<T> (std::vector<int16_t>& )> compute)
121 {
122 /* Feature cache to be captured by lambda function. */
123 static std::vector<std::vector<T>> featureCache = std::vector<std::vector<T>>(cacheSize);
124
125 return [=](std::vector<int16_t>& audioDataWindow,
126 size_t index,
127 bool useCache,
128 size_t featuresOverlapIndex)
129 {
130 T* tensorData = tflite::GetTensorData<T>(inputTensor);
131 std::vector<T> features;
132
133 /* Reuse features from cache if cache is ready and sliding windows overlap.
134 * Overlap is in the beginning of sliding window with a size of a feature cache. */
135 if (useCache && index < featureCache.size()) {
136 features = std::move(featureCache[index]);
137 } else {
138 features = std::move(compute(audioDataWindow));
139 }
140 auto size = features.size();
141 auto sizeBytes = sizeof(T) * size;
142 std::memcpy(tensorData + (index * size), features.data(), sizeBytes);
143
144 /* Start renewing cache as soon iteration goes out of the windows overlap. */
145 if (index >= featuresOverlapIndex) {
146 featureCache[index - featuresOverlapIndex] = std::move(features);
147 }
148 };
149 }
150
    /* Explicit instantiations for the two element types dispatched by
     * GetFeatureCalculator(): int8_t (quantised models) and float. */
    template std::function<void (std::vector<int16_t>&, size_t , bool, size_t)>
    KWSPreProcess::FeatureCalc<int8_t>(TfLiteTensor* inputTensor,
                                       size_t cacheSize,
                                       std::function<std::vector<int8_t> (std::vector<int16_t>&)> compute);

    template std::function<void(std::vector<int16_t>&, size_t, bool, size_t)>
    KWSPreProcess::FeatureCalc<float>(TfLiteTensor* inputTensor,
                                      size_t cacheSize,
                                      std::function<std::vector<float>(std::vector<int16_t>&)> compute);
160
161
162 std::function<void (std::vector<int16_t>&, int, bool, size_t)>
163 KWSPreProcess::GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc, TfLiteTensor* inputTensor, size_t cacheSize)
164 {
165 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)> mfccFeatureCalc;
166
167 TfLiteQuantization quant = inputTensor->quantization;
168
169 if (kTfLiteAffineQuantization == quant.type) {
170 auto *quantParams = (TfLiteAffineQuantization *) quant.params;
171 const float quantScale = quantParams->scale->data[0];
172 const int quantOffset = quantParams->zero_point->data[0];
173
174 switch (inputTensor->type) {
175 case kTfLiteInt8: {
176 mfccFeatureCalc = this->FeatureCalc<int8_t>(inputTensor,
177 cacheSize,
178 [=, &mfcc](std::vector<int16_t>& audioDataWindow) {
179 return mfcc.MfccComputeQuant<int8_t>(audioDataWindow,
180 quantScale,
181 quantOffset);
182 }
183 );
184 break;
185 }
186 default:
187 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
188 }
189 } else {
190 mfccFeatureCalc = this->FeatureCalc<float>(inputTensor, cacheSize,
191 [&mfcc](std::vector<int16_t>& audioDataWindow) {
192 return mfcc.MfccCompute(audioDataWindow); }
193 );
194 }
195 return mfccFeatureCalc;
196 }
197
    /* Binds the classifier, label set and results vector used when decoding the
     * model output; the model pointer is kept for access to the output tensor.
     * NOTE(review): m_scoreThreshold is stored but not used by DoPostProcess()
     * in this file - confirm whether thresholding happens elsewhere. */
    KWSPostProcess::KWSPostProcess(Classifier& classifier, Model* model,
                                   const std::vector<std::string>& labels,
                                   std::vector<ClassificationResult>& results, float scoreThreshold)
            :m_kwsClassifier{classifier},
             m_labels{labels},
             m_results{results},
             m_scoreThreshold{scoreThreshold}
    {
        /* An uninitialised model is only logged; the pointer is stored regardless. */
        if (!model->IsInited()) {
            printf_err("Model is not initialised!.\n");
        }
        this->m_model = model;
    }
211
212 bool KWSPostProcess::DoPostProcess()
213 {
214 return this->m_kwsClassifier.GetClassificationResults(
215 this->m_model->GetOutputTensor(0), this->m_results,
216 this->m_labels, 1, true);
217 }
218
219} /* namespace app */
220} /* namespace arm */