blob: 829bf2baa7d1bfddca5927a91f13faaa3224ea67 [file] [log] [blame]
/*
 * SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef KWS_PROCESSING_HPP
18#define KWS_PROCESSING_HPP
19
#include "AudioUtils.hpp"
#include "BaseProcessing.hpp"
#include "KwsClassifier.hpp"
#include "MicroNetKwsMfcc.hpp"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>
26
27namespace arm {
28namespace app {
29
30 /**
31 * @brief Pre-processing class for Keyword Spotting use case.
32 * Implements methods declared by BasePreProcess and anything else needed
33 * to populate input tensors ready for inference.
34 */
Richard Burtonb40ecf82022-04-22 16:14:57 +010035 class KwsPreProcess : public BasePreProcess {
Richard Burtone6398cd2022-04-13 11:58:28 +010036
37 public:
38 /**
39 * @brief Constructor
Richard Burtonb40ecf82022-04-22 16:14:57 +010040 * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
41 * @param[in] numFeatures How many MFCC features to use.
42 * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
43 * for an inference.
44 * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
45 * sliding a window through the audio sample.
46 * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
Richard Burtone6398cd2022-04-13 11:58:28 +010047 **/
Richard Burtonb40ecf82022-04-22 16:14:57 +010048 explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
49 int mfccFrameLength, int mfccFrameStride);
Richard Burtone6398cd2022-04-13 11:58:28 +010050
51 /**
52 * @brief Should perform pre-processing of 'raw' input audio data and load it into
53 * TFLite Micro input tensors ready for inference.
54 * @param[in] input Pointer to the data that pre-processing will work on.
55 * @param[in] inputSize Size of the input data.
56 * @return true if successful, false otherwise.
57 **/
Richard Burtonec5e99b2022-10-05 11:00:37 +010058 bool DoPreProcess(const void* input, size_t inferenceIndex = 0) override;
Richard Burtone6398cd2022-04-13 11:58:28 +010059
Richard Burtone6398cd2022-04-13 11:58:28 +010060 size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
61 size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
62
63 private:
Richard Burtonb40ecf82022-04-22 16:14:57 +010064 TfLiteTensor* m_inputTensor; /* Model input tensor. */
Richard Burtone6398cd2022-04-13 11:58:28 +010065 const int m_mfccFrameLength;
66 const int m_mfccFrameStride;
Richard Burtonb40ecf82022-04-22 16:14:57 +010067 const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
Richard Burtone6398cd2022-04-13 11:58:28 +010068
69 audio::MicroNetKwsMFCC m_mfcc;
70 audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
71 size_t m_numMfccVectorsInAudioStride;
72 size_t m_numReusedMfccVectors;
73 std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
74
75 /**
76 * @brief Returns a function to perform feature calculation and populates input tensor data with
77 * MFCC data.
78 *
79 * Input tensor data type check is performed to choose correct MFCC feature data type.
80 * If tensor has an integer data type then original features are quantised.
81 *
82 * Warning: MFCC calculator provided as input must have the same life scope as returned function.
83 *
84 * @param[in] mfcc MFCC feature calculator.
85 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
86 * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
87 * @return Function to be called providing audio sample and sliding window index.
88 */
89 std::function<void (std::vector<int16_t>&, int, bool, size_t)>
90 GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
91 TfLiteTensor* inputTensor,
92 size_t cacheSize);
93
94 template<class T>
95 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
96 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
97 std::function<std::vector<T> (std::vector<int16_t>& )> compute);
98 };
99
100 /**
101 * @brief Post-processing class for Keyword Spotting use case.
102 * Implements methods declared by BasePostProcess and anything else needed
103 * to populate result vector.
104 */
Richard Burtonb40ecf82022-04-22 16:14:57 +0100105 class KwsPostProcess : public BasePostProcess {
Richard Burtone6398cd2022-04-13 11:58:28 +0100106
107 private:
Richard Burtonec5e99b2022-10-05 11:00:37 +0100108 TfLiteTensor* m_outputTensor; /* Model output tensor. */
109 KwsClassifier& m_kwsClassifier; /* KWS Classifier object. */
110 const std::vector<std::string>& m_labels; /* KWS Labels. */
111 std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
112 std::vector<std::vector<float>> m_resultHistory; /* Store previous results so they can be averaged. */
Richard Burtone6398cd2022-04-13 11:58:28 +0100113 public:
Richard Burtone6398cd2022-04-13 11:58:28 +0100114 /**
Richard Burtonc2911442022-04-22 09:08:21 +0100115 * @brief Constructor
Richard Burtonb40ecf82022-04-22 16:14:57 +0100116 * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
117 * @param[in] classifier Classifier object used to get top N results from classification.
118 * @param[in] labels Vector of string labels to identify each output of the model.
119 * @param[in/out] results Vector of classification results to store decoded outputs.
Richard Burtone6398cd2022-04-13 11:58:28 +0100120 **/
Richard Burtonec5e99b2022-10-05 11:00:37 +0100121 KwsPostProcess(TfLiteTensor* outputTensor, KwsClassifier& classifier,
Richard Burtone6398cd2022-04-13 11:58:28 +0100122 const std::vector<std::string>& labels,
Richard Burtonec5e99b2022-10-05 11:00:37 +0100123 std::vector<ClassificationResult>& results, size_t averagingWindowLen = 1);
Richard Burtone6398cd2022-04-13 11:58:28 +0100124
125 /**
Richard Burtonc2911442022-04-22 09:08:21 +0100126 * @brief Should perform post-processing of the result of inference then
127 * populate KWS result data for any later use.
128 * @return true if successful, false otherwise.
Richard Burtone6398cd2022-04-13 11:58:28 +0100129 **/
130 bool DoPostProcess() override;
131 };
132
133} /* namespace app */
134} /* namespace arm */
135
136#endif /* KWS_PROCESSING_HPP */