George Gekov | 23c2627 | 2021-08-16 11:32:10 +0100 | [diff] [blame] | 1 | // |
| 2 | // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. |
| 3 | // SPDX-License-Identifier: MIT |
| 4 | // |
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include "KeywordSpottingPipeline.hpp"
#include "CmdArgsParser.hpp"
#include "ArmnnNetworkExecutor.hpp"
#include "AudioCapture.hpp"
| 14 | |
// Names of the command line flags referred to by this program
const std::string AUDIO_FILE_PATH{"--audio-file-path"};
const std::string MODEL_FILE_PATH{"--model-file-path"};
const std::string LABEL_PATH{"--label-path"};
const std::string PREFERRED_BACKENDS{"--preferred-backends"};
const std::string HELP{"--help"};
| 20 | |
| 21 | /* |
| 22 | * The accepted options for this Speech Recognition executable |
| 23 | */ |
| 24 | static std::map<std::string, std::string> CMD_OPTIONS = |
| 25 | { |
| 26 | {AUDIO_FILE_PATH, "[REQUIRED] Path to the Audio file to run speech recognition on"}, |
| 27 | {MODEL_FILE_PATH, "[REQUIRED] Path to the Speech Recognition model to use"}, |
| 28 | {PREFERRED_BACKENDS, "[OPTIONAL] Takes the preferred backends in preference order, separated by comma." |
| 29 | " For example: CpuAcc,GpuAcc,CpuRef. Accepted options: [CpuAcc, CpuRef, GpuAcc]." |
| 30 | " Defaults to CpuAcc,CpuRef"} |
| 31 | }; |
| 32 | |
| 33 | /* |
| 34 | * Reads the user supplied backend preference, splits it by comma, and returns an ordered vector |
| 35 | */ |
| 36 | std::vector<armnn::BackendId> GetPreferredBackendList(const std::string& preferredBackends) |
| 37 | { |
| 38 | std::vector<armnn::BackendId> backends; |
| 39 | std::stringstream ss(preferredBackends); |
| 40 | |
| 41 | while (ss.good()) |
| 42 | { |
| 43 | std::string backend; |
| 44 | std::getline(ss, backend, ','); |
| 45 | backends.emplace_back(backend); |
| 46 | } |
| 47 | return backends; |
| 48 | } |
| 49 | |
// Keyword labels keyed by integer index; handed to the pipeline's
// post-processing step in main() together with the inference results.
std::map<int, std::string> labels{
    { 0, "silence"}, { 1, "unknown"},
    { 2, "yes"},     { 3, "no"},
    { 4, "up"},      { 5, "down"},
    { 6, "left"},    { 7, "right"},
    { 8, "on"},      { 9, "off"},
    {10, "stop"},    {11, "go"}
};
| 66 | |
| 67 | |
| 68 | int main(int argc, char* argv[]) |
| 69 | { |
| 70 | printf("ArmNN major version: %d\n", ARMNN_MAJOR_VERSION); |
| 71 | std::map<std::string, std::string> options; |
| 72 | |
| 73 | //Read command line args |
| 74 | int result = ParseOptions(options, CMD_OPTIONS, argv, argc); |
| 75 | if (result != 0) |
| 76 | { |
| 77 | return result; |
| 78 | } |
| 79 | |
| 80 | // Create the ArmNN inference runner |
| 81 | common::PipelineOptions pipelineOptions; |
| 82 | pipelineOptions.m_ModelName = "DS_CNN_CLUSTERED_INT8"; |
| 83 | pipelineOptions.m_ModelFilePath = GetSpecifiedOption(options, MODEL_FILE_PATH); |
| 84 | if (CheckOptionSpecified(options, PREFERRED_BACKENDS)) |
| 85 | { |
| 86 | pipelineOptions.m_backends = GetPreferredBackendList( |
| 87 | (GetSpecifiedOption(options, PREFERRED_BACKENDS))); |
| 88 | } |
| 89 | else |
| 90 | { |
| 91 | pipelineOptions.m_backends = {"CpuAcc", "CpuRef"}; |
| 92 | } |
| 93 | |
| 94 | kws::IPipelinePtr kwsPipeline = kws::CreatePipeline(pipelineOptions); |
| 95 | |
| 96 | //Extract audio data from sound file |
| 97 | auto filePath = GetSpecifiedOption(options, AUDIO_FILE_PATH); |
| 98 | std::vector<float> audioData = audio::AudioCapture::LoadAudioFile(filePath); |
| 99 | |
| 100 | audio::AudioCapture capture; |
| 101 | //todo: read samples and stride from pipeline |
| 102 | capture.InitSlidingWindow(audioData.data(), |
| 103 | audioData.size(), |
| 104 | kwsPipeline->getInputSamplesSize(), |
| 105 | kwsPipeline->getInputSamplesSize()/2); |
| 106 | |
| 107 | //Loop through audio data buffer |
| 108 | while (capture.HasNext()) |
| 109 | { |
| 110 | std::vector<float> audioBlock = capture.Next(); |
| 111 | common::InferenceResults<int8_t> results; |
| 112 | |
| 113 | //Prepare input tensors |
| 114 | std::vector<int8_t> preprocessedData = kwsPipeline->PreProcessing(audioBlock); |
| 115 | //Run inference |
| 116 | kwsPipeline->Inference(preprocessedData, results); |
| 117 | //Decode output |
| 118 | kwsPipeline->PostProcessing(results, labels, |
| 119 | [](int index, std::string& label, float prob) -> void { |
| 120 | printf("Keyword \"%s\", index %d:, probability %f\n", |
| 121 | label.c_str(), |
| 122 | index, |
| 123 | prob); |
| 124 | }); |
| 125 | } |
| 126 | |
| 127 | return 0; |
| 128 | } |