# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
3
4"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
5
6import sys
7import os
8from argparse import ArgumentParser
9
10import numpy as np
11import sounddevice as sd
12
13script_dir = os.path.dirname(__file__)
14sys.path.insert(1, os.path.join(script_dir, '..', 'common'))
15
16from network_executor import ArmnnNetworkExecutor
17from utils import prepare_input_tensors, dequantize_output
18from mfcc import AudioPreprocessor, MFCC, MFCCParams
19from audio_utils import decode, display_text
20from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio
21
# Model Specific Labels
# Output index -> keyword mapping for the 12-class Google Speech Commands style model.
labels = dict(enumerate([
    'silence', 'unknown', 'yes', 'no', 'up', 'down',
    'left', 'right', 'on', 'off', 'stop', 'go',
]))
35
36
def parse_args():
    """Build the CLI parser for the KWS demo and return the parsed arguments.

    Omitting --audio_file_path selects live microphone streaming; supplying it
    selects pre-recorded file processing.
    """
    parser = ArgumentParser(description="KWS with PyArmNN")
    # Optional: when absent, main() falls back to live microphone capture.
    parser.add_argument("--audio_file_path", required=False, type=str,
                        help="Path to the audio file to perform KWS")
    parser.add_argument("--duration", type=int, default=0,
                        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
                recording. Defaults to infinite.""")
    parser.add_argument("--model_file_path", required=True, type=str,
                        help="Path to KWS model to use")
    parser.add_argument("--preferred_backends", type=str, nargs="+", default=["CpuAcc", "CpuRef"],
                        help="""List of backends in order of preference for optimizing
                subgraphs, falling back to the next backend in the list on unsupported
                layers. Defaults to [CpuAcc, CpuRef]""")
    return parser.parse_args()
68
69
def recognise_speech(audio_data, network, preprocessor, threshold):
    """Run one KWS inference pass over a window of audio and print the result.

    Preprocesses the window into input tensors, runs the network, dequantizes
    every output tensor, decodes the top label, and displays it only when its
    confidence exceeds *threshold*.
    """
    # Build the input tensors from the raw audio window.
    input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)

    # Execute inference.
    raw_outputs = network.run(input_tensors)

    # Dequantize each output tensor using its matching binding info.
    dequantized_result = [
        dequantize_output(ofm, network.output_binding_info[idx])
        for idx, ofm in enumerate(raw_outputs)
    ]

    # Decode to (label, confidence) and report only confident detections.
    decoded_result = decode(dequantized_result, labels)
    if decoded_result[1] > threshold:
        display_text(decoded_result)
85
86
def main(args):
    """Run the KWS demo: process a pre-recorded file if one was supplied,
    otherwise stream live microphone audio until the duration elapses.
    """
    audio_file = args.audio_file_path
    # No file path supplied -> capture from the microphone instead.
    streaming_enabled = not audio_file

    # Create the ArmNN inference runner
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Model-specific audio requirements; `overlap` is the number of samples
    # rewound between consecutive data windows.
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000,
                                              sampling_freq=16000, mono=True)

    # MFCC feature extractor + windowing preprocessor matched to the model input.
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
    preprocessor = AudioPreprocessor(MFCC(mfcc_params), model_input_size=49, stride=320)

    # Minimum confidence required before a classification is displayed.
    threshold = 0.90

    if not streaming_enabled:
        # File mode: read the whole file in overlapping windows and classify each.
        print("Processing Audio Frames...")
        buffer = capture_audio(audio_file, audio_capture_params)
        for audio_data in buffer:
            recognise_speech(audio_data, network, preprocessor, threshold)
        return

    # Microphone mode: set up the capture stream and classify until it stops.
    record_stream = CaptureAudioStream(audio_capture_params)
    record_stream.set_stream_defaults()
    record_stream.set_recording_duration(args.duration)
    record_stream.countdown()

    with sd.InputStream(callback=record_stream.callback):
        print("Recording audio. Please speak.")
        while record_stream.is_active:
            audio_data = record_stream.capture_data()
            recognise_speech(audio_data, network, preprocessor, threshold)
            record_stream.is_first_window = False
        print("\nFinished recording.")
132
133
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    cli_args = parse_args()
    main(cli_args)