blob: bd47987a598f57214e21a22c82cccf1b5aa4b641 [file] [log] [blame]
George Gekov23c26272021-08-16 11:32:10 +01001//
2// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3// SPDX-License-Identifier: MIT
4//
5
6#pragma once
7
#include "ArmnnNetworkExecutor.hpp"
#include "Decoder.hpp"
#include "MFCC.hpp"
#include "DsCNNPreprocessor.hpp"

#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
12
13namespace kws
14{
15/**
16 * Generic Keyword Spotting pipeline with 3 steps: data pre-processing, inference execution and inference
17 * result post-processing.
18 *
19 */
20class KWSPipeline
21{
22public:
23
24 /**
25 * Creates speech recognition pipeline with given network executor and decoder.
26 * @param executor - unique pointer to inference runner
27 * @param decoder - unique pointer to inference results decoder
28 */
29 KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
30 std::unique_ptr<Decoder> decoder,
31 std::unique_ptr<DsCNNPreprocessor> preProcessor);
32
33 /**
34 * @brief Standard audio pre-processing implementation.
35 *
36 * Preprocesses and prepares the data for inference by
37 * extracting the MFCC features.
38
39 * @param[in] audio - the raw audio data
40 */
41
42 std::vector<int8_t> PreProcessing(std::vector<float>& audio);
43
44 /**
45 * @brief Executes inference
46 *
47 * Calls inference runner provided during instance construction.
48 *
49 * @param[in] preprocessedData - input inference data. Data type should be aligned with input tensor.
50 * @param[out] result - raw inference results.
51 */
52 void Inference(const std::vector<int8_t>& preprocessedData, common::InferenceResults<int8_t>& result);
53
54 /**
55 * @brief Standard inference results post-processing implementation.
56 *
57 * Decodes inference results using decoder provided during construction.
58 *
59 * @param[in] inferenceResult - inference results to be decoded.
60 * @param[in] labels - the words we use for the model
61 */
62 void PostProcessing(common::InferenceResults<int8_t>& inferenceResults,
63 std::map<int, std::string>& labels,
64 const std::function<void (int, std::string&, float)>& callback);
65
66 /**
67 * @brief Get the number of samples for the pipeline input
68
69 * @return - number of samples for the pipeline
70 */
71 int getInputSamplesSize();
72
73protected:
74 std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> m_executor;
75 std::unique_ptr<Decoder> m_decoder;
76 std::unique_ptr<DsCNNPreprocessor> m_preProcessor;
77};
78
79using IPipelinePtr = std::unique_ptr<kws::KWSPipeline>;
80
81/**
82 * Constructs speech recognition pipeline based on configuration provided.
83 *
84 * @param[in] config - speech recognition pipeline configuration.
85 * @param[in] labels - asr labels
86 *
87 * @return unique pointer to asr pipeline.
88 */
89IPipelinePtr CreatePipeline(common::PipelineOptions& config);
90
91};// namespace kws