//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "Preprocess.hpp"
namespace asr
{
/**
* Generic Speech Recognition pipeline with 3 steps: data pre-processing, inference execution and
* inference result post-processing. A usage sketch follows the class definition below.
*
*/
class ASRPipeline
{
public:
/**
* Creates speech recognition pipeline with given network executor and decoder.
* @param executor - unique pointer to inference runner
* @param decoder - unique pointer to inference results decoder
*/
ASRPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
std::unique_ptr<Decoder> decoder);
/**
* @brief Standard audio pre-processing implementation.
*
* Preprocesses and prepares the data for inference by
* extracting the MFCC features.
* @param[in] audio - the raw audio data
* @param[in,out] preprocessor - the preprocessor object, which handles the data preparation
* @return the preprocessed data, ready to be used as inference input
*/
template<typename Tin, typename Tout>
std::vector<Tout> PreProcessing(std::vector<Tin>& audio, Preprocess& preprocessor)
{
// Number of raw audio samples needed: one full window plus one stride for every additional MFCC vector.
int audioDataToPreProcess = preprocessor._m_windowLen +
((preprocessor._m_mfcc._m_params.m_numMfccVectors - 1) * preprocessor._m_windowStride);
// Each MFCC vector is stacked with its first and second order derivatives, hence the factor of 3.
int outputBufferSize = preprocessor._m_mfcc._m_params.m_numMfccVectors
* preprocessor._m_mfcc._m_params.m_numMfccFeatures * 3;
std::vector<Tout> outputBuffer(outputBufferSize);
preprocessor.Invoke(audio.data(), audioDataToPreProcess, outputBuffer, m_executor->GetQuantizationOffset(),
m_executor->GetQuantizationScale());
return outputBuffer;
}
/**
* @brief Executes inference
*
* Calls inference runner provided during instance construction.
*
* @param[in] preprocessedData - input inference data; its type must match the network's input tensor type.
* @param[out] result - raw inference results.
*/
template<typename T>
void Inference(const std::vector<T>& preprocessedData, common::InferenceResults<int8_t>& result)
{
// Size of the tensor payload in bytes.
size_t data_bytes = sizeof(T) * preprocessedData.size();
m_executor->Run(preprocessedData.data(), data_bytes, result);
}
/**
* @brief Standard inference results post-processing implementation.
*
* Decodes inference results using decoder provided during construction.
*
* @param[in] inferenceResult - inference results to be decoded.
* @param[in,out] isFirstWindow - flag marking the first window of the sliding window; reset to false after the first call.
* @param[in] isLastWindow - flag marking the last window of the sliding window.
* @param[in] currentRContext - the right context of the output text, printed if this is the last window.
*/
template<typename T>
void PostProcessing(common::InferenceResults<int8_t>& inferenceResult,
bool& isFirstWindow,
bool isLastWindow,
std::string currentRContext)
{
// The network output is a sequence of rows of label scores, one row per output time step.
// Each sliding window is split into a left, a middle and a right context block of rows.
int rowLength = 29;
int middleContextStart = 49;
int middleContextEnd = 99;
int leftContextStart = 0;
int rightContextStart = 100;
int rightContextEnd = 148;
std::vector<T> contextToProcess;
// For the first window we keep both the left and the middle context of the output
if(isFirstWindow)
{
std::vector<T> chunk(&inferenceResult[0][leftContextStart],
&inferenceResult[0][middleContextEnd * rowLength]);
contextToProcess = chunk;
}
// Else we only keep the middle context of the output
else
{
std::vector<T> chunk(&inferenceResult[0][middleContextStart * rowLength],
&inferenceResult[0][middleContextEnd * rowLength]);
contextToProcess = chunk;
}
std::string output = this->m_decoder->DecodeOutput<T>(contextToProcess);
isFirstWindow = false;
std::cout << output << std::flush;
// If this is the last window, we print the right context of the output
if(isLastWindow)
{
std::vector<T> rContext(&inferenceResult[0][rightContextStart * rowLength],
&inferenceResult[0][rightContextEnd * rowLength]);
currentRContext = this->m_decoder->DecodeOutput(rContext);
std::cout << currentRContext << std::endl;
}
}
protected:
std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
std::unique_ptr<Decoder> m_decoder;
};
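// A minimal usage sketch for the three pipeline stages, assuming the pipeline was created via
// asr::CreatePipeline and that `audioBlock` (raw float samples), `preprocessor`, `isFirstWindow`,
// `isLastWindow` and `currentRContext` are set up by the caller; the names are illustrative and
// not part of this header:
//
//     std::vector<int8_t> input =
//         pipeline->PreProcessing<float, int8_t>(audioBlock, preprocessor);
//     common::InferenceResults<int8_t> results;
//     pipeline->Inference<int8_t>(input, results);
//     pipeline->PostProcessing<int8_t>(results, isFirstWindow, isLastWindow, currentRContext);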
using IPipelinePtr = std::unique_ptr<asr::ASRPipeline>;
/**
* Constructs a speech recognition pipeline from the provided configuration.
*
* @param[in] config - speech recognition pipeline configuration.
* @param[in] labels - ASR output labels (map from output class index to its character label)
*
* @return unique pointer to asr pipeline.
*/
IPipelinePtr CreatePipeline(common::PipelineOptions& config, std::map<int, std::string>& labels);
} // namespace asr
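// A minimal construction sketch, assuming the caller populates common::PipelineOptions with the
// model settings and loads the label map elsewhere (both are defined outside this header):
//
//     common::PipelineOptions config;          // model path, preferred backends, etc.
//     std::map<int, std::string> labels = ...; // output class index -> character label
//     asr::IPipelinePtr pipeline = asr::CreatePipeline(config, labels);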