Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
| 5 | |
| 6 | #pragma once |
| 7 | |
| 8 | #include "ArmnnNetworkExecutor.hpp" |
| 9 | #include "Decoder.hpp" |
| 10 | #include "MFCC.hpp" |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 11 | #include "Wav2LetterPreprocessor.hpp" |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 12 | |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 13 | namespace asr |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 14 | { |
| 15 | /** |
| 16 | * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference |
| 17 | * result post-processing. |
| 18 | * |
| 19 | */ |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 20 | class ASRPipeline |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 21 | { |
| 22 | public: |
| 23 | |
| 24 | /** |
| 25 | * Creates speech recognition pipeline with given network executor and decoder. |
| 26 | * @param executor - unique pointer to inference runner |
| 27 | * @param decoder - unique pointer to inference results decoder |
| 28 | */ |
| 29 | ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor, |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 30 | std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor); |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 31 | |
| 32 | /** |
| 33 | * @brief Standard audio pre-processing implementation. |
| 34 | * |
| 35 | * Preprocesses and prepares the data for inference by |
| 36 | * extracting the MFCC features. |
| 37 | |
| 38 | * @param[in] audio - the raw audio data |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 39 | * @param[out] preprocessor - the preprocessor object, which handles the data preparation |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 40 | */ |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 41 | std::vector<int8_t> PreProcessing(std::vector<float>& audio); |
| 42 | |
| 43 | int getInputSamplesSize(); |
| 44 | int getSlidingWindowOffset(); |
| 45 | |
| 46 | // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself |
| 47 | // Will need to be refactored so that hard coded values are not defined outside of model settings |
| 48 | int SLIDING_WINDOW_OFFSET; |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 49 | |
| 50 | /** |
| 51 | * @brief Executes inference |
| 52 | * |
| 53 | * Calls inference runner provided during instance construction. |
| 54 | * |
| 55 | * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor. |
| 56 | * @param[out] result - raw inference results. |
| 57 | */ |
| 58 | template<typename T> |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 59 | void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result) |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 60 | { |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 61 | size_t data_bytes = sizeof(T) * preprocessedData.size(); |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 62 | m_executor->Run(preprocessedData.data(), data_bytes, result); |
| 63 | } |
| 64 | |
| 65 | /** |
| 66 | * @brief Standard inference results post-processing implementation. |
| 67 | * |
| 68 | * Decodes inference results using decoder provided during construction. |
| 69 | * |
| 70 | * @param[in] inferenceResult - inference results to be decoded. |
| 71 | * @param[in] isFirstWindow - for checking if this is the first window of the sliding window. |
| 72 | * @param[in] isLastWindow - for checking if this is the last window of the sliding window. |
| 73 | * @param[in] currentRContext - the right context of the output text. To be output if it is the last window. |
| 74 | */ |
| 75 | template<typename T> |
| 76 | void PostProcessing(common::InferenceResults<int8_t>& inferenceResult, |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 77 | bool& isFirstWindow, |
| 78 | bool isLastWindow, |
| 79 | std::string currentRContext) |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 80 | { |
| 81 | int rowLength = 29; |
| 82 | int middleContextStart = 49; |
| 83 | int middleContextEnd = 99; |
| 84 | int leftContextStart = 0; |
| 85 | int rightContextStart = 100; |
| 86 | int rightContextEnd = 148; |
| 87 | |
| 88 | std::vector<T> contextToProcess; |
| 89 | |
| 90 | // If isFirstWindow we keep the left context of the output |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 91 | if (isFirstWindow) |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 92 | { |
| 93 | std::vector<T> chunk(&inferenceResult[0][leftContextStart], |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 94 | &inferenceResult[0][middleContextEnd * rowLength]); |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 95 | contextToProcess = chunk; |
| 96 | } |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 97 | else |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 98 | { |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 99 | // Else we only keep the middle context of the output |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 100 | std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength], |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 101 | &inferenceResult[0][middleContextEnd * rowLength]); |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 102 | contextToProcess = chunk; |
| 103 | } |
| 104 | std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess); |
| 105 | isFirstWindow = false; |
| 106 | std::cout << output << std::flush; |
| 107 | |
| 108 | // If this is the last window, we print the right context of the output |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 109 | if (isLastWindow) |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 110 | { |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 111 | std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength], |
| 112 | &inferenceResult[0][rightContextEnd * rowLength]); |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 113 | currentRContext = this->m_decoder->DecodeOutput(rContext); |
| 114 | std::cout << currentRContext << std::endl; |
| 115 | } |
| 116 | } |
| 117 | |
| 118 | protected: |
| 119 | std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor; |
| 120 | std::unique_ptr<Decoder> m_decoder; |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 121 | std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor; |
Éanna Ó Catháin | c6ab02a | 2021-04-07 14:35:25 +0100 | [diff] [blame] | 122 | }; |
| 123 | |
| 124 | using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>; |
| 125 | |
| 126 | /** |
| 127 | * Constructs speech recognition pipeline based on configuration provided. |
| 128 | * |
| 129 | * @param[in] config - speech recognition pipeline configuration. |
| 130 | * @param[in] labels - asr labels |
| 131 | * |
| 132 | * @return unique pointer to asr pipeline. |
| 133 | */ |
| 134 | IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels); |
| 135 | |
George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 136 | } // namespace asr |