blob: 5a85e963968a11e927a7a755ecf118d3ed4cc2d5 [file] [log] [blame]
/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "AdProcessing.hpp"
18
19#include "AdModel.hpp"
20
21namespace arm {
22namespace app {
23
24AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor,
25 uint32_t melSpectrogramFrameLen,
26 uint32_t melSpectrogramFrameStride,
27 float adModelTrainingMean):
28 m_validInstance{false},
29 m_melSpectrogramFrameLen{melSpectrogramFrameLen},
30 m_melSpectrogramFrameStride{melSpectrogramFrameStride},
31 /**< Model is trained on features downsampled 2x */
32 m_inputResizeScale{2},
33 /**< We are choosing to move by 20 frames across the audio for each inference. */
34 m_numMelSpecVectorsInAudioStride{20},
35 m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride},
36 m_melSpec{melSpectrogramFrameLen}
37{
Kshitij Sisodiaaa4bcb12022-05-06 09:13:03 +010038 UNUSED(this->m_melSpectrogramFrameStride);
39
Richard Burton4e002792022-05-04 09:45:02 +010040 if (!inputTensor) {
41 printf_err("Invalid input tensor provided to pre-process\n");
42 return;
43 }
44
45 TfLiteIntArray* inputShape = inputTensor->dims;
46
47 if (!inputShape) {
48 printf_err("Invalid input tensor dims\n");
49 return;
50 }
51
52 const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx];
53 const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx];
54
55 /* Deduce the data length required for 1 inference from the network parameters. */
56 this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) *
57 melSpectrogramFrameStride) +
58 melSpectrogramFrameLen;
59 this->m_numReusedFeatureVectors = kNumRows -
60 (this->m_numMelSpecVectorsInAudioStride /
61 this->m_inputResizeScale);
62 this->m_melSpec.Init();
63
64 /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
65 * "resizing" done here by multiplying stride by resize scale. */
66 this->m_melWindowSlider = audio::SlidingWindow<const int16_t>(
67 nullptr, /* to be populated later. */
68 this->m_audioDataWindowSize,
69 melSpectrogramFrameLen,
70 melSpectrogramFrameStride * this->m_inputResizeScale);
71
72 /* Construct feature calculation function. */
73 this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor,
74 this->m_numReusedFeatureVectors,
75 adModelTrainingMean);
76 this->m_validInstance = true;
77}
78
79bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize)
80{
81 /* Check that we have a valid instance. */
82 if (!this->m_validInstance) {
83 printf_err("Invalid pre-processor instance\n");
84 return false;
85 }
86
87 /* We expect that we can traverse the size with which the MEL spectrogram
88 * sliding window was initialised with. */
89 if (!input || inputSize < this->m_audioDataWindowSize) {
90 printf_err("Invalid input provided for pre-processing\n");
91 return false;
92 }
93
94 /* We moved to the next window - set the features sliding to the new address. */
95 this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input));
96
97 /* The first window does not have cache ready. */
98 const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0;
99
100 /* Start calculating features inside one audio sliding window. */
101 while (this->m_melWindowSlider.HasNext()) {
102 const int16_t* melSpecWindow = this->m_melWindowSlider.Next();
103 std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(
104 melSpecWindow,
105 melSpecWindow + this->m_melSpectrogramFrameLen);
106
107 /* Compute features for this window and write them to input tensor. */
108 this->m_featureCalc(melSpecAudioData,
109 this->m_melWindowSlider.Index(),
110 useCache,
111 this->m_numMelSpecVectorsInAudioStride,
112 this->m_inputResizeScale);
113 }
114
115 return true;
116}
117
118uint32_t AdPreProcess::GetAudioWindowSize()
119{
120 return this->m_audioDataWindowSize;
121}
122
123uint32_t AdPreProcess::GetAudioDataStride()
124{
125 return this->m_audioDataStride;
126}
127
128void AdPreProcess::SetAudioWindowIndex(uint32_t idx)
129{
130 this->m_audioWindowIndex = idx;
131}
132
133AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) :
134 m_outputTensor {outputTensor}
135{}
136
137bool AdPostProcess::DoPostProcess()
138{
139 switch (this->m_outputTensor->type) {
140 case kTfLiteInt8:
141 this->Dequantize<int8_t>();
142 break;
143 default:
144 printf_err("Unsupported tensor type");
145 return false;
146 }
147
148 math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec);
149 return true;
150}
151
152float AdPostProcess::GetOutputValue(uint32_t index)
153{
154 if (index < this->m_dequantizedOutputVec.size()) {
155 return this->m_dequantizedOutputVec[index];
156 }
157 printf_err("Invalid index for output\n");
158 return 0.0;
159}
160
161std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
162GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
163 TfLiteTensor* inputTensor,
164 size_t cacheSize,
165 float trainingMean)
166{
Maksims Svecovs154a2b12022-08-30 11:53:19 +0100167 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc = nullptr;
Richard Burton4e002792022-05-04 09:45:02 +0100168
169 TfLiteQuantization quant = inputTensor->quantization;
170
171 if (kTfLiteAffineQuantization == quant.type) {
172
173 auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params);
174 const float quantScale = quantParams->scale->data[0];
175 const int quantOffset = quantParams->zero_point->data[0];
176
177 switch (inputTensor->type) {
178 case kTfLiteInt8: {
179 melSpecFeatureCalc = FeatureCalc<int8_t>(
180 inputTensor,
181 cacheSize,
182 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
183 return melSpec.MelSpecComputeQuant<int8_t>(
184 audioDataWindow,
185 quantScale,
186 quantOffset,
187 trainingMean);
188 }
189 );
190 break;
191 }
192 default:
193 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
194 }
195 } else {
196 melSpecFeatureCalc = FeatureCalc<float>(
197 inputTensor,
198 cacheSize,
199 [=, &melSpec](
200 std::vector<int16_t>& audioDataWindow) {
201 return melSpec.ComputeMelSpec(
202 audioDataWindow,
203 trainingMean);
204 });
205 }
206 return melSpecFeatureCalc;
207}
208
209} /* namespace app */
210} /* namespace arm */