blob: 0ede425b235dab23ef5e1d7807dda8d68a518c79 [file] [log] [blame]
Richard Burtone6398cd2022-04-13 11:58:28 +01001/*
2 * Copyright (c) 2022 Arm Limited. All rights reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17#ifndef KWS_PROCESSING_HPP
18#define KWS_PROCESSING_HPP
19
#include "AudioUtils.hpp"
#include "BaseProcessing.hpp"
#include "Classifier.hpp"
#include "MicroNetKwsMfcc.hpp"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>
26
27namespace arm {
28namespace app {
29
30 /**
31 * @brief Pre-processing class for Keyword Spotting use case.
32 * Implements methods declared by BasePreProcess and anything else needed
33 * to populate input tensors ready for inference.
34 */
Richard Burtonb40ecf82022-04-22 16:14:57 +010035 class KwsPreProcess : public BasePreProcess {
Richard Burtone6398cd2022-04-13 11:58:28 +010036
37 public:
38 /**
39 * @brief Constructor
Richard Burtonb40ecf82022-04-22 16:14:57 +010040 * @param[in] inputTensor Pointer to the TFLite Micro input Tensor.
41 * @param[in] numFeatures How many MFCC features to use.
42 * @param[in] numFeatureFrames Number of MFCC vectors that need to be calculated
43 * for an inference.
44 * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
45 * sliding a window through the audio sample.
46 * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
Richard Burtone6398cd2022-04-13 11:58:28 +010047 **/
Richard Burtonb40ecf82022-04-22 16:14:57 +010048 explicit KwsPreProcess(TfLiteTensor* inputTensor, size_t numFeatures, size_t numFeatureFrames,
49 int mfccFrameLength, int mfccFrameStride);
Richard Burtone6398cd2022-04-13 11:58:28 +010050
51 /**
52 * @brief Should perform pre-processing of 'raw' input audio data and load it into
53 * TFLite Micro input tensors ready for inference.
54 * @param[in] input Pointer to the data that pre-processing will work on.
55 * @param[in] inputSize Size of the input data.
56 * @return true if successful, false otherwise.
57 **/
58 bool DoPreProcess(const void* input, size_t inputSize) override;
59
60 size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
61 size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
62 size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
63
64 private:
Richard Burtonb40ecf82022-04-22 16:14:57 +010065 TfLiteTensor* m_inputTensor; /* Model input tensor. */
Richard Burtone6398cd2022-04-13 11:58:28 +010066 const int m_mfccFrameLength;
67 const int m_mfccFrameStride;
Richard Burtonb40ecf82022-04-22 16:14:57 +010068 const size_t m_numMfccFrames; /* How many sets of m_numMfccFeats. */
Richard Burtone6398cd2022-04-13 11:58:28 +010069
70 audio::MicroNetKwsMFCC m_mfcc;
71 audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
72 size_t m_numMfccVectorsInAudioStride;
73 size_t m_numReusedMfccVectors;
74 std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
75
76 /**
77 * @brief Returns a function to perform feature calculation and populates input tensor data with
78 * MFCC data.
79 *
80 * Input tensor data type check is performed to choose correct MFCC feature data type.
81 * If tensor has an integer data type then original features are quantised.
82 *
83 * Warning: MFCC calculator provided as input must have the same life scope as returned function.
84 *
85 * @param[in] mfcc MFCC feature calculator.
86 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
87 * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
88 * @return Function to be called providing audio sample and sliding window index.
89 */
90 std::function<void (std::vector<int16_t>&, int, bool, size_t)>
91 GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
92 TfLiteTensor* inputTensor,
93 size_t cacheSize);
94
95 template<class T>
96 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
97 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
98 std::function<std::vector<T> (std::vector<int16_t>& )> compute);
99 };
100
101 /**
102 * @brief Post-processing class for Keyword Spotting use case.
103 * Implements methods declared by BasePostProcess and anything else needed
104 * to populate result vector.
105 */
Richard Burtonb40ecf82022-04-22 16:14:57 +0100106 class KwsPostProcess : public BasePostProcess {
Richard Burtone6398cd2022-04-13 11:58:28 +0100107
108 private:
Richard Burtonb40ecf82022-04-22 16:14:57 +0100109 TfLiteTensor* m_outputTensor; /* Model output tensor. */
110 Classifier& m_kwsClassifier; /* KWS Classifier object. */
111 const std::vector<std::string>& m_labels; /* KWS Labels. */
112 std::vector<ClassificationResult>& m_results; /* Results vector for a single inference. */
Richard Burtone6398cd2022-04-13 11:58:28 +0100113
114 public:
Richard Burtone6398cd2022-04-13 11:58:28 +0100115 /**
Richard Burtonc2911442022-04-22 09:08:21 +0100116 * @brief Constructor
Richard Burtonb40ecf82022-04-22 16:14:57 +0100117 * @param[in] outputTensor Pointer to the TFLite Micro output Tensor.
118 * @param[in] classifier Classifier object used to get top N results from classification.
119 * @param[in] labels Vector of string labels to identify each output of the model.
120 * @param[in/out] results Vector of classification results to store decoded outputs.
Richard Burtone6398cd2022-04-13 11:58:28 +0100121 **/
Richard Burtonb40ecf82022-04-22 16:14:57 +0100122 KwsPostProcess(TfLiteTensor* outputTensor, Classifier& classifier,
Richard Burtone6398cd2022-04-13 11:58:28 +0100123 const std::vector<std::string>& labels,
Richard Burtonc2911442022-04-22 09:08:21 +0100124 std::vector<ClassificationResult>& results);
Richard Burtone6398cd2022-04-13 11:58:28 +0100125
126 /**
Richard Burtonc2911442022-04-22 09:08:21 +0100127 * @brief Should perform post-processing of the result of inference then
128 * populate KWS result data for any later use.
129 * @return true if successful, false otherwise.
Richard Burtone6398cd2022-04-13 11:58:28 +0100130 **/
131 bool DoPostProcess() override;
132 };
133
134} /* namespace app */
135} /* namespace arm */
136
137#endif /* KWS_PROCESSING_HPP */