# Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
"""Keyword Spotting with PyArmNN demo for processing live microphone data or pre-recorded files."""
import sys
import os
from argparse import ArgumentParser

import numpy as np
import sounddevice as sd

# Make the shared helper modules in ../common importable before importing them
script_dir = os.path.dirname(__file__)
sys.path.insert(1, os.path.join(script_dir, '..', 'common'))

from network_executor import ArmnnNetworkExecutor
from utils import prepare_input_tensors, dequantize_output
from mfcc import AudioPreprocessor, MFCC, MFCCParams
from audio_utils import decode, display_text
from audio_capture import AudioCaptureParams, CaptureAudioStream, capture_audio

# Model Specific Labels
labels = {0: 'silence',
          1: 'unknown',
          2: 'yes',
          3: 'no',
          4: 'up',
          5: 'down',
          6: 'left',
          7: 'right',
          8: 'on',
          9: 'off',
          10: 'stop',
          11: 'go'}


def parse_args():
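    """Parses the command line arguments for the demo."""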
    parser = ArgumentParser(description="KWS with PyArmNN")
    parser.add_argument(
        "--audio_file_path",
        required=False,
        type=str,
        help="Path to the audio file to perform KWS",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=0,
        help="""Duration for recording audio in seconds. Values <= 0 result in infinite
        recording. Defaults to infinite.""",
    )
    parser.add_argument(
        "--model_file_path",
        required=True,
        type=str,
        help="Path to KWS model to use",
    )
    parser.add_argument(
        "--preferred_backends",
        type=str,
        nargs="+",
        default=["CpuAcc", "CpuRef"],
        help="""List of backends in order of preference for optimizing
        subgraphs, falling back to the next backend in the list on unsupported
        layers. Defaults to [CpuAcc, CpuRef]""",
    )
    return parser.parse_args()


def recognise_speech(audio_data, network, preprocessor, threshold):
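    """Preprocesses one window of audio, runs inference and displays the decoded
    keyword if its probability exceeds the given threshold."""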
    # Prepare the input tensors
    input_tensors = prepare_input_tensors(audio_data, network.input_binding_info, preprocessor)

    # Run inference
    output_result = network.run(input_tensors)

    dequantized_result = []
    for index, ofm in enumerate(output_result):
        dequantized_result.append(dequantize_output(ofm, network.output_binding_info[index]))

    # Decode the text and display the result if above threshold
    decoded_result = decode(dequantized_result, labels)

    if decoded_result[1] > threshold:
        display_text(decoded_result)


def main(args):
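    """Runs keyword spotting on a pre-recorded file or on a live microphone stream."""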
    # Read command line args and fall back to mic streaming if no file path is supplied
    audio_file = args.audio_file_path
    streaming_enabled = not args.audio_file_path

    # Create the ArmNN inference runner
    network = ArmnnNetworkExecutor(args.model_file_path, args.preferred_backends)

    # Specify model specific audio data requirements
    # The overlap value specifies the number of samples to rewind between each data window
    audio_capture_params = AudioCaptureParams(dtype=np.float32, overlap=2000, min_samples=16000,
                                              sampling_freq=16000, mono=True)

    # Create the preprocessor
    mfcc_params = MFCCParams(sampling_freq=16000, num_fbank_bins=40, mel_lo_freq=20, mel_hi_freq=4000,
                             num_mfcc_feats=10, frame_len=640, use_htk_method=True, n_fft=1024)
    mfcc = MFCC(mfcc_params)
    preprocessor = AudioPreprocessor(mfcc, model_input_size=49, stride=320)
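    # Note: frame_len=640 samples is 40 ms at the 16 kHz sampling rate and stride=320 is 20 ms,
    # so each inference consumes 49 overlapping windows of 10 MFCC features, per the values above.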
    # Set the threshold for displaying a classification and commence stream or file processing
    threshold = 0.90

    if streaming_enabled:
        # Initialise audio stream
        record_stream = CaptureAudioStream(audio_capture_params)
        record_stream.set_stream_defaults()
        record_stream.set_recording_duration(args.duration)
        record_stream.countdown()

        with sd.InputStream(callback=record_stream.callback):
            print("Recording audio. Please speak.")
            while record_stream.is_active:
                audio_data = record_stream.capture_data()
                recognise_speech(audio_data, network, preprocessor, threshold)
                record_stream.is_first_window = False

            print("\nFinished recording.")

    # If a file path has been supplied, read the file in and run inference on its frames
    else:
        print("Processing Audio Frames...")
        buffer = capture_audio(audio_file, audio_capture_params)
        for audio_data in buffer:
            recognise_speech(audio_data, network, preprocessor, threshold)


if __name__ == "__main__":
    args = parse_args()
    main(args)