# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT

"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""

import sys
import os
from argparse import ArgumentParser

import numpy as np
import sounddevice as sd

script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))

from network_executor import ArmnnNetworkExecutor
from utils import prepare_input_data, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio

# Model specific labels (the 12-class Google Speech Commands keyword set)
labels = {0: 'silence',
          1: 'unknown',
          2: 'yes',
          3: 'no',
          4: 'up',
          5: 'down',
          6: 'left',
          7: 'right',
          8: 'on',
          9: 'off',
          10: 'stop',
          11: 'go'}


def parse_args():
    parser = ArgumentParser(description="KWS with PyArmNN")
    parser.add_argument(
        "--audio_file_path",
        required=False,
        type=str,
        help="Path to the audio file on which to perform KWS",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=0,
        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
        recording. Defaults to infinite.""",
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        type=str,
        help="Path to the KWS model to use",
    )
    parser.add_argument(
        "--preferred_backends",
        type=str,
        nargs="+",
        default=["CpuAcc", "CpuRef"],
        help="""List of backends in order of preference for optimizing
        subgraphs, falling back to the next backend in the list on unsupported
        layers. Defaults to [CpuAcc, CpuRef]""",
    )
    return parser.parse_args()


def recognise_speech(audio_data, network, preprocessor, threshold):
    # Prepare the input tensors
    input_data = prepare_input_data(audio_data, network.get_data_type(), network.get_input_quantization_scale(0),
                                    network.get_input_quantization_offset(0), preprocessor)
    # Run inference
    output_result = network.run([input_data])

    # Dequantize each output tensor (returned unchanged when the output is not quantized)
    dequantized_result = []
    for index, ofm in enumerate(output_result):
        dequantized_result.append(dequantize_output(ofm, network.is_output_quantized(index),
                                                    network.get_output_quantization_scale(index),
                                                    network.get_output_quantization_offset(index)))

    # Decode the text and display the result if it is above the threshold
    decoded_result = decode(dequantized_result, labels)

    if decoded_result[1] > threshold:
        display_text(decoded_result)


def main(args):
    # Read command line args and invoke mic streaming if no file path supplied
    audio_file = args.audio_file_path
    if args.audio_file_path:
        streaming_enabled = False
    else:
        streaming_enabled = True
    # Create the ArmNN inference runner
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Specify model specific audio data requirements
    # Overlap value specifies the number of samples to rewind between each data window
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000, sampling_freq=16000,
                                              mono=True)
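    # With min_samples=16000 and overlap=2000, consecutive windows share 2000 samples,
    # so the stream advances 16000 - 2000 = 14000 samples (0.875 s at 16 kHz) per window.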

    # Create the preprocessor
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
    mfcc = MFCC(mfcc_params)
    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)
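    # Preprocessor geometry: 49 frames of frame_len=640 samples at a stride of 320
    # span 640 + 48 * 320 = 16000 samples, i.e. exactly one 16000-sample capture
    # window, with each frame contributing num_mfcc_feats=10 MFCC features.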

    # Set threshold for displaying classification and commence stream or file processing
    threshold = 0.90
    if streaming_enabled:
        # Initialise audio stream
        record_stream = CaptureAudioStream(audio_capture_params)
        record_stream.set_stream_defaults()
        record_stream.set_recording_duration(args.duration)
        record_stream.countdown()

        with sd.InputStream(callback=record_stream.callback):
            print("Recording audio. Please speak.")
            while record_stream.is_active:

                audio_data = record_stream.capture_data()
                recognise_speech(audio_data, network, preprocessor, threshold)
                record_stream.is_first_window = False
            print("\nFinished recording.")

    # If a file path has been supplied, read it in and run inference
    else:
        print("Processing Audio Frames...")
        buffer = capture_audio(audio_file, audio_capture_params)
        for audio_data in buffer:
            recognise_speech(audio_data, network, preprocessor, threshold)


if __name__ == "__main__":
    args = parse_args()
    main(args)
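
# Example invocations (the script and file names below are illustrative placeholders):
#
#   Classify keywords in a pre-recorded file:
#     python run_keyword_spotting.py --model_file_path kws_model.tflite \
#                                    --audio_file_path my_speech.wav
#
#   Stream from the default microphone for 30 seconds on the CPU reference backend:
#     python run_keyword_spotting.py --model_file_path kws_model.tflite \
#                                    --duration 30 --preferred_backends CpuRef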