//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "Wav2LetterPreprocessor.hpp"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

George Gekov23c26272021-08-16 11:32:10 +010013namespace asr
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010014{
15/**
16 * Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and inference
17 * result post-processing.
18 *
19 */
George Gekov23c26272021-08-16 11:32:10 +010020class ASRPipeline
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010021{
22public:
23
24 /**
25 * Creates speech recognition pipeline with given network executor and decoder.
26 * @param executor - unique pointer to inference runner
27 * @param decoder - unique pointer to inference results decoder
28 */
29 ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
George Gekov23c26272021-08-16 11:32:10 +010030 std::unique_ptr<Decoder> decoder, std::unique_ptr<Wav2LetterPreprocessor> preprocessor);
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010031
32 /**
33 * @brief Standard audio pre-processing implementation.
34 *
35 * Preprocesses and prepares the data for inference by
36 * extracting the MFCC features.
37
38 * @param[in] audio - the raw audio data
George Gekov23c26272021-08-16 11:32:10 +010039 * @param[out] preprocessor - the preprocessor object, which handles the data preparation
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010040 */
George Gekov23c26272021-08-16 11:32:10 +010041 std::vector<int8_t> PreProcessing(std::vector<float>& audio);
42
43 int getInputSamplesSize();
44 int getSlidingWindowOffset();
45
46 // Exposing hardcoded constant as it can only be derived from model knowledge and not from model itself
47 // Will need to be refactored so that hard coded values are not defined outside of model settings
48 int SLIDING_WINDOW_OFFSET;
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010049
50 /**
51 * @brief Executes inference
52 *
53 * Calls inference runner provided during instance construction.
54 *
55 * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor.
56 * @param[out] result - raw inference results.
57 */
58 template<typename T>
George Gekov23c26272021-08-16 11:32:10 +010059 void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010060 {
George Gekov23c26272021-08-16 11:32:10 +010061 size_t data_bytes = sizeof(T) * preprocessedData.size();
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010062 m_executor->Run(preprocessedData.data(), data_bytes, result);
63 }
64
65 /**
66 * @brief Standard inference results post-processing implementation.
67 *
68 * Decodes inference results using decoder provided during construction.
69 *
70 * @param[in] inferenceResult - inference results to be decoded.
71 * @param[in] isFirstWindow - for checking if this is the first window of the sliding window.
72 * @param[in] isLastWindow - for checking if this is the last window of the sliding window.
73 * @param[in] currentRContext - the right context of the output text. To be output if it is the last window.
74 */
75 template<typename T>
76 void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
George Gekov23c26272021-08-16 11:32:10 +010077 bool& isFirstWindow,
78 bool isLastWindow,
79 std::string currentRContext)
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010080 {
81 int rowLength = 29;
82 int middleContextStart = 49;
83 int middleContextEnd = 99;
84 int leftContextStart = 0;
85 int rightContextStart = 100;
86 int rightContextEnd = 148;
87
88 std::vector<T> contextToProcess;
89
90 // If isFirstWindow we keep the left context of the output
George Gekov23c26272021-08-16 11:32:10 +010091 if (isFirstWindow)
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010092 {
93 std::vector<T> chunk(&inferenceResult[0][leftContextStart],
George Gekov23c26272021-08-16 11:32:10 +010094 &inferenceResult[0][middleContextEnd * rowLength]);
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010095 contextToProcess = chunk;
96 }
George Gekov23c26272021-08-16 11:32:10 +010097 else
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +010098 {
George Gekov23c26272021-08-16 11:32:10 +010099 // Else we only keep the middle context of the output
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +0100100 std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
George Gekov23c26272021-08-16 11:32:10 +0100101 &inferenceResult[0][middleContextEnd * rowLength]);
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +0100102 contextToProcess = chunk;
103 }
104 std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
105 isFirstWindow = false;
106 std::cout << output << std::flush;
107
108 // If this is the last window, we print the right context of the output
George Gekov23c26272021-08-16 11:32:10 +0100109 if (isLastWindow)
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +0100110 {
George Gekov23c26272021-08-16 11:32:10 +0100111 std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
112 &inferenceResult[0][rightContextEnd * rowLength]);
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +0100113 currentRContext = this->m_decoder->DecodeOutput(rContext);
114 std::cout << currentRContext << std::endl;
115 }
116 }
117
118protected:
119 std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
120 std::unique_ptr<Decoder> m_decoder;
George Gekov23c26272021-08-16 11:32:10 +0100121 std::unique_ptr<Wav2LetterPreprocessor> m_preProcessor;
Éanna Ó Catháinc6ab02a2021-04-07 14:35:25 +0100122};
123
124using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
125
126/**
127 * Constructs speech recognition pipeline based on configuration provided.
128 *
129 * @param[in] config - speech recognition pipeline configuration.
130 * @param[in] labels - asr labels
131 *
132 * @return unique pointer to asr pipeline.
133 */
134IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
135
George Gekov23c26272021-08-16 11:32:10 +0100136} // namespace asr