blob: 10efcd8ce7726f36a23086f6303f5c2383650646 [file] [log] [blame]
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "ArmnnNetworkExecutor.hpp"
#include "AudioCapture.hpp"
#include "CmdArgsParser.hpp"
#include "KeywordSpottingPipeline.hpp"
14
// Command-line option keys accepted by this executable.
const std::string AUDIO_FILE_PATH = "--audio-file-path";
const std::string MODEL_FILE_PATH = "--model-file-path";
// NOTE(review): LABEL_PATH and HELP are declared here but never listed in
// CMD_OPTIONS below — confirm whether they are intentionally unused.
const std::string LABEL_PATH = "--label-path";
const std::string PREFERRED_BACKENDS = "--preferred-backends";
const std::string HELP = "--help";
20
21/*
22 * The accepted options for this Speech Recognition executable
23 */
24static std::map<std::string, std::string> CMD_OPTIONS =
25{
26 {AUDIO_FILE_PATH, "[REQUIRED] Path to the Audio file to run speech recognition on"},
27 {MODEL_FILE_PATH, "[REQUIRED] Path to the Speech Recognition model to use"},
28 {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma."
29 " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]."
30 " Defaults to CpuAcc,CpuRef"}
31};
32
33/*
34 * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector
35 */
36std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends)
37{
38 std::vector<armnn::BackendId> backends;
39 std::stringstream ss(preferredBackends);
40
41 while (ss.good())
42 {
43 std::string backend;
44 std::getline(ss, backend, ',');
45 backends.emplace_back(backend);
46 }
47 return backends;
48}

// Labels for this model: class index (as produced by the network's output
// layer) -> keyword string. Passed to PostProcessing() in main() to decode
// inference results. Presumably matches the 12-class vocabulary the
// DS_CNN_CLUSTERED_INT8 model was trained on — TODO confirm against the model.
std::map<int, std::string> labels =
{
    {0, "silence"},
    {1, "unknown"},
    {2, "yes"},
    {3, "no"},
    {4, "up"},
    {5, "down"},
    {6, "left"},
    {7, "right"},
    {8, "on"},
    {9, "off"},
    {10, "stop"},
    {11, "go"}
};
66
67
68int main(int argc, char* argv[])
69{
70 printf("ArmNN major version: %d\n", ARMNN_MAJOR_VERSION);
71 std::map<std::string, std::string> options;
72
73 //Read command line args
74 int result = ParseOptions(options, CMD_OPTIONS, argv, argc);
75 if (result != 0)
76 {
77 return result;
78 }
79
80 // Create the ArmNN inference runner
81 common::PipelineOptions pipelineOptions;
82 pipelineOptions.m_ModelName = "DS_CNN_CLUSTERED_INT8";
83 pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH);
84 if (CheckOptionSpecified(options, PREFERRED_BACKENDS))
85 {
86 pipelineOptions.m_backends = GetPreferredBackendList(
87 (GetSpecifiedOption(options, PREFERRED_BACKENDS)));
88 }
89 else
90 {
91 pipelineOptions.m_backends = {"CpuAcc", "CpuRef"};
92 }
93
94 kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(pipelineOptions);
95
96 //Extract audio data from sound file
97 auto filePath = GetSpecifiedOption(options, AUDIO_FILE_PATH);
98 std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(filePath);
99
100 audio::AudioCapture capture;
101 //todo: read samples and stride from pipeline
102 capture.InitSlidingWindow(audioData.data(),
103 audioData.size(),
104 kwsPipeline->getInputSamplesSize(),
105 kwsPipeline->getInputSamplesSize()/2);
106
107 //Loop through audio data buffer
108 while (capture.HasNext())
109 {
110 std::vector<float> audioBlock = capture.Next();
111 common::InferenceResults<int8_t> results;
112
113 //Prepare input tensors
114 std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(audioBlock);
115 //Run inference
116 kwsPipeline->Inference(preprocessedData, results);
117 //Decode output
118 kwsPipeline->PostProcessing(results, labels,
119 [](int index, std::string& label, float prob) -> void {
120 printf("Keyword \"%s\", index %d:, probability %f\n",
121 label.c_str(),
122 index,
123 prob);
124 });
125 }
126
127 return 0;
128}