/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#include "AdProcessing.hpp"
18
19#include "AdModel.hpp"
20
21namespace arm {
22namespace app {
23
24AdPreProcess::AdPreProcess(TfLiteTensor* inputTensor,
25 uint32_t melSpectrogramFrameLen,
26 uint32_t melSpectrogramFrameStride,
27 float adModelTrainingMean):
28 m_validInstance{false},
29 m_melSpectrogramFrameLen{melSpectrogramFrameLen},
30 m_melSpectrogramFrameStride{melSpectrogramFrameStride},
31 /**< Model is trained on features downsampled 2x */
32 m_inputResizeScale{2},
33 /**< We are choosing to move by 20 frames across the audio for each inference. */
34 m_numMelSpecVectorsInAudioStride{20},
35 m_audioDataStride{m_numMelSpecVectorsInAudioStride * melSpectrogramFrameStride},
36 m_melSpec{melSpectrogramFrameLen}
37{
38 if (!inputTensor) {
39 printf_err("Invalid input tensor provided to pre-process\n");
40 return;
41 }
42
43 TfLiteIntArray* inputShape = inputTensor->dims;
44
45 if (!inputShape) {
46 printf_err("Invalid input tensor dims\n");
47 return;
48 }
49
50 const uint32_t kNumRows = inputShape->data[AdModel::ms_inputRowsIdx];
51 const uint32_t kNumCols = inputShape->data[AdModel::ms_inputColsIdx];
52
53 /* Deduce the data length required for 1 inference from the network parameters. */
54 this->m_audioDataWindowSize = (((this->m_inputResizeScale * kNumCols) - 1) *
55 melSpectrogramFrameStride) +
56 melSpectrogramFrameLen;
57 this->m_numReusedFeatureVectors = kNumRows -
58 (this->m_numMelSpecVectorsInAudioStride /
59 this->m_inputResizeScale);
60 this->m_melSpec.Init();
61
62 /* Creating a Mel Spectrogram sliding window for the data required for 1 inference.
63 * "resizing" done here by multiplying stride by resize scale. */
64 this->m_melWindowSlider = audio::SlidingWindow<const int16_t>(
65 nullptr, /* to be populated later. */
66 this->m_audioDataWindowSize,
67 melSpectrogramFrameLen,
68 melSpectrogramFrameStride * this->m_inputResizeScale);
69
70 /* Construct feature calculation function. */
71 this->m_featureCalc = GetFeatureCalculator(this->m_melSpec, inputTensor,
72 this->m_numReusedFeatureVectors,
73 adModelTrainingMean);
74 this->m_validInstance = true;
75}
76
77bool AdPreProcess::DoPreProcess(const void* input, size_t inputSize)
78{
79 /* Check that we have a valid instance. */
80 if (!this->m_validInstance) {
81 printf_err("Invalid pre-processor instance\n");
82 return false;
83 }
84
85 /* We expect that we can traverse the size with which the MEL spectrogram
86 * sliding window was initialised with. */
87 if (!input || inputSize < this->m_audioDataWindowSize) {
88 printf_err("Invalid input provided for pre-processing\n");
89 return false;
90 }
91
92 /* We moved to the next window - set the features sliding to the new address. */
93 this->m_melWindowSlider.Reset(static_cast<const int16_t*>(input));
94
95 /* The first window does not have cache ready. */
96 const bool useCache = this->m_audioWindowIndex > 0 && this->m_numReusedFeatureVectors > 0;
97
98 /* Start calculating features inside one audio sliding window. */
99 while (this->m_melWindowSlider.HasNext()) {
100 const int16_t* melSpecWindow = this->m_melWindowSlider.Next();
101 std::vector<int16_t> melSpecAudioData = std::vector<int16_t>(
102 melSpecWindow,
103 melSpecWindow + this->m_melSpectrogramFrameLen);
104
105 /* Compute features for this window and write them to input tensor. */
106 this->m_featureCalc(melSpecAudioData,
107 this->m_melWindowSlider.Index(),
108 useCache,
109 this->m_numMelSpecVectorsInAudioStride,
110 this->m_inputResizeScale);
111 }
112
113 return true;
114}
115
116uint32_t AdPreProcess::GetAudioWindowSize()
117{
118 return this->m_audioDataWindowSize;
119}
120
121uint32_t AdPreProcess::GetAudioDataStride()
122{
123 return this->m_audioDataStride;
124}
125
126void AdPreProcess::SetAudioWindowIndex(uint32_t idx)
127{
128 this->m_audioWindowIndex = idx;
129}
130
131AdPostProcess::AdPostProcess(TfLiteTensor* outputTensor) :
132 m_outputTensor {outputTensor}
133{}
134
135bool AdPostProcess::DoPostProcess()
136{
137 switch (this->m_outputTensor->type) {
138 case kTfLiteInt8:
139 this->Dequantize<int8_t>();
140 break;
141 default:
142 printf_err("Unsupported tensor type");
143 return false;
144 }
145
146 math::MathUtils::SoftmaxF32(this->m_dequantizedOutputVec);
147 return true;
148}
149
150float AdPostProcess::GetOutputValue(uint32_t index)
151{
152 if (index < this->m_dequantizedOutputVec.size()) {
153 return this->m_dequantizedOutputVec[index];
154 }
155 printf_err("Invalid index for output\n");
156 return 0.0;
157}
158
159std::function<void (std::vector<int16_t>&, int, bool, size_t, size_t)>
160GetFeatureCalculator(audio::AdMelSpectrogram& melSpec,
161 TfLiteTensor* inputTensor,
162 size_t cacheSize,
163 float trainingMean)
164{
165 std::function<void (std::vector<int16_t>&, size_t, bool, size_t, size_t)> melSpecFeatureCalc;
166
167 TfLiteQuantization quant = inputTensor->quantization;
168
169 if (kTfLiteAffineQuantization == quant.type) {
170
171 auto* quantParams = static_cast<TfLiteAffineQuantization*>(quant.params);
172 const float quantScale = quantParams->scale->data[0];
173 const int quantOffset = quantParams->zero_point->data[0];
174
175 switch (inputTensor->type) {
176 case kTfLiteInt8: {
177 melSpecFeatureCalc = FeatureCalc<int8_t>(
178 inputTensor,
179 cacheSize,
180 [=, &melSpec](std::vector<int16_t>& audioDataWindow) {
181 return melSpec.MelSpecComputeQuant<int8_t>(
182 audioDataWindow,
183 quantScale,
184 quantOffset,
185 trainingMean);
186 }
187 );
188 break;
189 }
190 default:
191 printf_err("Tensor type %s not supported\n", TfLiteTypeGetName(inputTensor->type));
192 }
193 } else {
194 melSpecFeatureCalc = FeatureCalc<float>(
195 inputTensor,
196 cacheSize,
197 [=, &melSpec](
198 std::vector<int16_t>& audioDataWindow) {
199 return melSpec.ComputeMelSpec(
200 audioDataWindow,
201 trainingMean);
202 });
203 }
204 return melSpecFeatureCalc;
205}
206
207} /* namespace app */
208} /* namespace arm */