/*
 * Copyright (c) 2022 Arm Limited. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17#ifndef KWS_PROCESSING_HPP
18#define KWS_PROCESSING_HPP
19
20#include <AudioUtils.hpp>
21#include "BaseProcessing.hpp"
22#include "Model.hpp"
23#include "Classifier.hpp"
24#include "MicroNetKwsMfcc.hpp"
25
26#include <functional>
27
28namespace arm {
29namespace app {
30
31 /**
32 * @brief Pre-processing class for Keyword Spotting use case.
33 * Implements methods declared by BasePreProcess and anything else needed
34 * to populate input tensors ready for inference.
35 */
36 class KWSPreProcess : public BasePreProcess {
37
38 public:
39 /**
40 * @brief Constructor
41 * @param[in] model Pointer to the the KWS Model object.
42 * @param[in] numFeatures How many MFCC features to use.
43 * @param[in] mfccFrameLength Number of audio samples used to calculate one set of MFCC values when
44 * sliding a window through the audio sample.
45 * @param[in] mfccFrameStride Number of audio samples between consecutive windows.
46 **/
47 explicit KWSPreProcess(Model* model, size_t numFeatures, int mfccFrameLength, int mfccFrameStride);
48
49 /**
50 * @brief Should perform pre-processing of 'raw' input audio data and load it into
51 * TFLite Micro input tensors ready for inference.
52 * @param[in] input Pointer to the data that pre-processing will work on.
53 * @param[in] inputSize Size of the input data.
54 * @return true if successful, false otherwise.
55 **/
56 bool DoPreProcess(const void* input, size_t inputSize) override;
57
58 size_t m_audioWindowIndex = 0; /* Index of audio slider, used when caching features in longer clips. */
59 size_t m_audioDataWindowSize; /* Amount of audio needed for 1 inference. */
60 size_t m_audioDataStride; /* Amount of audio to stride across if doing >1 inference in longer clips. */
61
62 private:
63 const int m_mfccFrameLength;
64 const int m_mfccFrameStride;
65
66 audio::MicroNetKwsMFCC m_mfcc;
67 audio::SlidingWindow<const int16_t> m_mfccSlidingWindow;
68 size_t m_numMfccVectorsInAudioStride;
69 size_t m_numReusedMfccVectors;
70 std::function<void (std::vector<int16_t>&, int, bool, size_t)> m_mfccFeatureCalculator;
71
72 /**
73 * @brief Returns a function to perform feature calculation and populates input tensor data with
74 * MFCC data.
75 *
76 * Input tensor data type check is performed to choose correct MFCC feature data type.
77 * If tensor has an integer data type then original features are quantised.
78 *
79 * Warning: MFCC calculator provided as input must have the same life scope as returned function.
80 *
81 * @param[in] mfcc MFCC feature calculator.
82 * @param[in,out] inputTensor Input tensor pointer to store calculated features.
83 * @param[in] cacheSize Size of the feature vectors cache (number of feature vectors).
84 * @return Function to be called providing audio sample and sliding window index.
85 */
86 std::function<void (std::vector<int16_t>&, int, bool, size_t)>
87 GetFeatureCalculator(audio::MicroNetKwsMFCC& mfcc,
88 TfLiteTensor* inputTensor,
89 size_t cacheSize);
90
91 template<class T>
92 std::function<void (std::vector<int16_t>&, size_t, bool, size_t)>
93 FeatureCalc(TfLiteTensor* inputTensor, size_t cacheSize,
94 std::function<std::vector<T> (std::vector<int16_t>& )> compute);
95 };
96
97 /**
98 * @brief Post-processing class for Keyword Spotting use case.
99 * Implements methods declared by BasePostProcess and anything else needed
100 * to populate result vector.
101 */
102 class KWSPostProcess : public BasePostProcess {
103
104 private:
105 Classifier& m_kwsClassifier;
106 const std::vector<std::string>& m_labels;
107 std::vector<ClassificationResult>& m_results;
108
109 public:
110 const float m_scoreThreshold;
111 /**
112 * @brief Constructor
113 * @param[in] classifier Classifier object used to get top N results from classification.
114 * @param[in] model Pointer to the the Image classification Model object.
115 * @param[in] labels Vector of string labels to identify each output of the model.
116 * @param[in] results Vector of classification results to store decoded outputs.
117 * @param[in] scoreThreshold Predicted model score must be larger than this value to be accepted.
118 **/
119 KWSPostProcess(Classifier& classifier, Model* model,
120 const std::vector<std::string>& labels,
121 std::vector<ClassificationResult>& results,
122 float scoreThreshold);
123
124 /**
125 * @brief Should perform post-processing of the result of inference then populate
126 * populate KWS result data for any later use.
127 * @return true if successful, false otherwise.
128 **/
129 bool DoPostProcess() override;
130 };
131
132} /* namespace app */
133} /* namespace arm */
134
135#endif /* KWS_PROCESSING_HPP */