blob: e32d9476e3805158d71e85dfd25ae57fe72df309 [file] [log] [blame]
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5
6#include "KeywordSpottingPipeline.hpp"
7#include "ArmnnNetworkExecutor.hpp"
8#include "DsCNNPreprocessor.hpp"
9
10namespace kws
11{
12KWSPipeline::KWSPipeline(std::unique_ptr<common::ArmnnNetworkExecutor<int8_t>> executor,
13 std::unique_ptr<Decoder> decoder,
14 std::unique_ptr<DsCNNPreprocessor> preProcessor
15 ) :
16 m_executor(std::move(executor)),
17 m_decoder(std::move(decoder)),
18 m_preProcessor(std::move(preProcessor)) {}
19
20
21std::vector<int8_t> KWSPipeline::PreProcessing(std::vector<float>& audio)
22{
23 return m_preProcessor->Invoke(audio.data(), audio.size(), m_executor->GetQuantizationOffset(),
24 m_executor->GetQuantizationScale());
25}
26
27void KWSPipeline::Inference(const std::vector<int8_t>& preprocessedData,
28 common::InferenceResults<int8_t>& result)
29{
30 m_executor->Run(preprocessedData.data(), preprocessedData.size(), result);
31}
32
33void KWSPipeline::PostProcessing(common::InferenceResults<int8_t>& inferenceResults,
34 std::map<int, std::string>& labels,
35 const std::function<void (int, std::string&, float)>& callback)
36{
37 std::pair<int,float> outputDecoder = this->m_decoder->decodeOutput(inferenceResults[0]);
38 int keywordIndex = std::get<0>(outputDecoder);
39 std::string output = labels[keywordIndex];
40 callback(keywordIndex, output, std::get<1>(outputDecoder));
41}
42
43int KWSPipeline::getInputSamplesSize()
44{
45 return this->m_preProcessor->m_windowLen +
46 ((this->m_preProcessor->m_mfcc->m_params.m_numMfccVectors - 1) *
47 this->m_preProcessor->m_windowStride);
48}
49
50IPipelinePtr CreatePipeline(common::PipelineOptions& config)
51{
52 if (config.m_ModelName == "DS_CNN_CLUSTERED_INT8")
53 {
54 //DS-CNN model settings
55 float SAMP_FREQ = 16000;
56 int MFCC_WINDOW_LEN = 640;
57 int MFCC_WINDOW_STRIDE = 320;
58 int NUM_MFCC_FEATS = 10;
59 int NUM_MFCC_VECTORS = 49;
60 //todo: calc in pipeline and use in main
61 int SAMPLES_PER_INFERENCE = NUM_MFCC_VECTORS * MFCC_WINDOW_STRIDE +
62 MFCC_WINDOW_LEN - MFCC_WINDOW_STRIDE; //16000
63 float MEL_LO_FREQ = 20;
64 float MEL_HI_FREQ = 4000;
65 int NUM_FBANK_BIN = 40;
66
67 MfccParams mfccParams(SAMP_FREQ,
68 NUM_FBANK_BIN,
69 MEL_LO_FREQ,
70 MEL_HI_FREQ,
71 NUM_MFCC_FEATS,
72 MFCC_WINDOW_LEN, false,
73 NUM_MFCC_VECTORS);
74
75 std::unique_ptr<DsCnnMFCC> mfccInst = std::make_unique<DsCnnMFCC>(mfccParams);
76 auto preprocessor = std::make_unique<kws::DsCNNPreprocessor>(
77 MFCC_WINDOW_LEN, MFCC_WINDOW_STRIDE, std::move(mfccInst));
78
79 auto executor = std::make_unique<common::ArmnnNetworkExecutor<int8_t>>(
80 config.m_ModelFilePath, config.m_backends);
81
82 auto decoder = std::make_unique<kws::Decoder>(executor->GetOutputQuantizationOffset(0),
83 executor->GetOutputQuantizationScale(0));
84
85 return std::make_unique<kws::KWSPipeline>(std::move(executor),
86 std::move(decoder), std::move(preprocessor));
87 }
88 else
89 {
90 throw std::invalid_argument("Unknown Model name: " + config.m_ModelName + " .");
91 }
92}
93
94};// namespace kws