//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "Preprocess.hpp"

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

namespace asr
{
/**
 * Generic speech recognition pipeline with 3 steps: data pre-processing, inference execution
 * and inference result post-processing.
 */
class ASRPipeline
{
public:

    /**
     * Creates a speech recognition pipeline with the given network executor and decoder.
     * @param executor - unique pointer to the inference runner
     * @param decoder - unique pointer to the inference results decoder
     */
    ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
                std::unique_ptr<Decoder> decoder);

    /**
     * @brief Standard audio pre-processing implementation.
     *
     * Pre-processes and prepares the data for inference by
     * extracting the MFCC features.
     *
     * @param[in] audio - the raw audio data
     * @param[in] preprocessor - the preprocessor object, which handles the data preparation
     * @return buffer containing the extracted features, ready to be used as network input
     */
    template<typename Tin, typename Tout>
    std::vector<Tout> PreProcessing(std::vector<Tin>& audio, Preprocess& preprocessor)
    {
        // Number of raw audio samples covered by the sliding MFCC windows.
        int audioDataToPreProcess = preprocessor._m_windowLen +
                ((preprocessor._m_mfcc._m_params.m_numMfccVectors - 1) * preprocessor._m_windowStride);
        // Output buffer size: three values per MFCC coefficient across all MFCC vectors.
        int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors
                * preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3;
        std::vector<Tout> outputBuffer(outputBufferSize);
        // Extract the features and quantize them with the executor's input quantization parameters.
        preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer,
                            m_executor->GetQuantizationOffset(),
                            m_executor->GetQuantizationScale());
        return outputBuffer;
    }
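
    // Illustrative usage sketch (not part of the API): given float audio samples and an int8
    // quantized network, one block of audio might be pre-processed as follows. `pipeline`,
    // `audioBlock` and `preprocessor` are placeholder names for objects created by the caller.
    //
    //     std::vector<float> audioBlock = /* next block of raw samples */;
    //     Preprocess preprocessor = /* built from the model's MFCC parameters */;
    //     std::vector<int8_t> networkInput =
    //             pipeline->PreProcessing<float, int8_t>(audioBlock, preprocessor);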

    /**
     * @brief Executes inference.
     *
     * Calls the inference runner provided during construction.
     *
     * @param[in] preprocessedData - input data for inference; the element type must match the
     *                               network's input tensor type.
     * @param[out] result - raw inference results.
     */
    template<typename T>
    void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
    {
        // Size of the input payload in bytes (element size times element count).
        size_t dataBytes = sizeof(T) * preprocessedData.size();
        m_executor->Run(preprocessedData.data(), dataBytes, result);
    }
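
    // Illustrative usage sketch: running inference on a pre-processed block. `networkInput` is
    // assumed to come from PreProcessing() above; `results` receives the raw output tensors.
    //
    //     common::InferenceResults<int8_t> results;
    //     pipeline->Inference<int8_t>(networkInput, results);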

    /**
     * @brief Standard inference results post-processing implementation.
     *
     * Decodes inference results using the decoder provided during construction.
     *
     * @param[in]     inferenceResult - inference results to be decoded.
     * @param[in,out] isFirstWindow - true if this is the first window of the sliding window;
     *                                cleared once the window has been processed.
     * @param[in]     isLastWindow - true if this is the last window of the sliding window.
     * @param[in]     currentRContext - the right context of the output text, printed when the
     *                                  last window is processed.
     */
    template<typename T>
    void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
                        bool& isFirstWindow,
                        bool isLastWindow,
                        std::string currentRContext)
    {
        // Each row of the output tensor holds one score per label; the rows are split into a
        // left, middle and right context window.
        int rowLength = 29;
        int middleContextStart = 49;
        int middleContextEnd = 99;
        int leftContextStart = 0;
        int rightContextStart = 100;
        int rightContextEnd = 148;

        std::vector<T> contextToProcess;

        // For the first window we keep the left context of the output.
        if(isFirstWindow)
        {
            std::vector<T> chunk(&inferenceResult[0][leftContextStart],
                                 &inferenceResult[0][middleContextEnd * rowLength]);
            contextToProcess = chunk;
        }
        // Otherwise we only keep the middle context of the output.
        else
        {
            std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
                                 &inferenceResult[0][middleContextEnd * rowLength]);
            contextToProcess = chunk;
        }
        std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
        isFirstWindow = false;
        std::cout << output << std::flush;

        // If this is the last window, we also print the right context of the output.
        if(isLastWindow)
        {
            std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
                                    &inferenceResult[0][rightContextEnd * rowLength]);
            currentRContext = this->m_decoder->DecodeOutput<T>(rContext);
            std::cout << currentRContext << std::endl;
        }
    }
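
    // Illustrative usage sketch: decoding and printing the text for one window. `results` is
    // assumed to come from Inference() above; `isFirstWindow` and `currentRContext` are state
    // maintained by the caller's sliding-window loop, and `isLastWindow` would be true only for
    // the final block of audio.
    //
    //     bool isFirstWindow = true;
    //     std::string currentRContext;
    //     pipeline->PostProcessing<int8_t>(results, isFirstWindow, isLastWindow, currentRContext);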

protected:
    std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
    std::unique_ptr<Decoder> m_decoder;
};

using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;

/**
 * Constructs a speech recognition pipeline based on the configuration provided.
 *
 * @param[in] config - speech recognition pipeline configuration.
 * @param[in] labels - ASR labels, mapping network output indices to characters.
 *
 * @return unique pointer to the ASR pipeline.
 */
IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
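
// Illustrative end-to-end sketch of how the pipeline might be driven. The audio source, the
// label loading and the way PipelineOptions is populated are assumptions for this example and
// are handled by the accompanying sample application, not by this header.
//
//     common::PipelineOptions options;                   // populated from user configuration
//     std::map<int, std::string> labels = LoadLabels();  // hypothetical helper
//     asr::IPipelinePtr pipeline = asr::CreatePipeline(options, labels);
//
//     Preprocess preprocessor = /* built from the model's MFCC parameters */;
//     bool isFirstWindow = true;
//     std::string currentRContext;
//     while (/* more audio blocks are available */)
//     {
//         std::vector<float> audioBlock = /* next block of raw samples */;
//         bool isLastWindow = /* true for the final block */;
//         std::vector<int8_t> input = pipeline->PreProcessing<float, int8_t>(audioBlock, preprocessor);
//         common::InferenceResults<int8_t> results;
//         pipeline->Inference<int8_t>(input, results);
//         pipeline->PostProcessing<int8_t>(results, isFirstWindow, isLastWindow, currentRContext);
//     }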

} // namespace asr