alexander | f42f568 | 2021-07-16 11:30:56 +0100 | [diff] [blame] | 1 | # Copyright © 2021 Arm Ltd and Contributors. All rights reserved. |
| 2 | # SPDX-License-Identifier: MIT |
| 3 | """Contains CaptureAudioStream class for capturing chunks of audio data from incoming |
| 4 | stream and generic capture_audio function for capturing from files.""" |
| 5 | import collections |
| 6 | import time |
| 7 | from queue import Queue |
| 8 | from typing import Generator |
| 9 | |
| 10 | import numpy as np |
| 11 | import sounddevice as sd |
| 12 | import soundfile as sf |
| 13 | |
# Immutable bundle of the sampling parameters shared by the capture helpers in this module:
#   dtype         - sample data type requested for captured audio
#   overlap       - number of samples shared between consecutive windows
#   min_samples   - number of samples in each window/block
#   sampling_freq - sampling rate used for stream capture
#   mono          - truthy when audio should be reduced to a single channel
AudioCaptureParams = collections.namedtuple(
    'AudioCaptureParams',
    ('dtype', 'overlap', 'min_samples', 'sampling_freq', 'mono'),
)
| 16 | |
| 17 | |
def capture_audio(audio_file_path, params_tuple) -> Generator[np.ndarray, None, None]:
    """Creates a generator that yields audio data from a file.

    The file is read in blocks of ``params_tuple.min_samples`` frames, with
    ``params_tuple.overlap`` passed through as the overlap between consecutive blocks.
    Data is padded with zeros if necessary to make up the minimum number of samples.

    Args:
        audio_file_path: Path to audio file provided by user.
        params_tuple: Sampling parameters for model used. Must provide the attributes
            ``min_samples``, ``dtype``, ``overlap`` and ``mono``.

    Yields:
        Blocks of audio data of minimum sample size. When ``params_tuple.mono`` is set
        the channels of every block are averaged, giving a 1-D array; otherwise the
        2-D block is yielded unchanged.
    """
    with sf.SoundFile(audio_file_path) as audio_file:
        for block in audio_file.blocks(
                blocksize=params_tuple.min_samples,
                dtype=params_tuple.dtype,
                always_2d=True,
                fill_value=0,
                overlap=params_tuple.overlap
        ):
            if params_tuple.mono:
                # Blocks are requested with always_2d=True, so axis 1 is the channel
                # axis. The previous guard tested shape[0] (the frame count), which
                # skipped the down-mix only for one-frame blocks and left those 2-D.
                # Reducing unconditionally keeps mono output consistently 1-D.
                block = np.mean(block, dtype=block.dtype, axis=1)
            yield block
| 38 | |
| 39 | |
class CaptureAudioStream:
    """Collects windows of audio data delivered by a live input stream.

    Blocks handed to ``callback`` are placed on a small queue. ``capture_data`` takes
    the next block off the queue and returns a window of ``min_samples`` samples,
    optionally prefixed with the tail of the previous window (``overlap`` samples).

    Attributes:
        audio_capture_params: Sampling parameters (``dtype``, ``overlap``,
            ``min_samples``, ``sampling_freq``, ``mono``).
        collection: 1-D working buffer of ``min_samples + overlap`` samples.
        is_active: Cleared once the requested number of blocks has been received.
        is_first_window: True until the first window has been captured.
        duration: Total number of samples to record, or False to record endlessly.
        block_count: Number of blocks that make up ``duration``.
        current_block: Number of blocks received so far (only counted for finite durations).
        queue: Queue (max 2 entries) of blocks waiting to be captured.
    """

    def __init__(self, audio_capture_params):
        """Stores the sampling parameters and allocates the zeroed working buffer.

        Args:
            audio_capture_params: Sampling parameters for the model used.
        """
        self.audio_capture_params = audio_capture_params
        # One full window plus room for the samples carried over as overlap.
        self.collection = np.zeros(
            audio_capture_params.min_samples + audio_capture_params.overlap,
            dtype=audio_capture_params.dtype)
        self.is_active = True
        self.is_first_window = True
        self.duration = False
        self.block_count = 0
        self.current_block = 0
        self.queue = Queue(2)

    def set_stream_defaults(self):
        """Discovers input devices on the system and sets default stream parameters.

        Prints the available devices, prompts the user to pick one by index or name,
        then configures the ``sounddevice`` defaults from the capture parameters.
        """
        print(sd.query_devices())
        device = input("Select input device by index or name: ")

        try:
            sd.default.device = int(device)
        except ValueError:
            # Not an integer index, so treat the answer as a device name.
            sd.default.device = str(device)

        sd.default.samplerate = self.audio_capture_params.sampling_freq
        sd.default.blocksize = self.audio_capture_params.min_samples
        sd.default.dtype = self.audio_capture_params.dtype
        sd.default.channels = 1 if self.audio_capture_params.mono else 2

    def set_recording_duration(self, duration):
        """Sets a time duration (in integer seconds) for recording audio.

        The duration is raised to the minimum needed to fill one window of
        ``min_samples`` samples. Durations less than 1 result in endless recording.

        Args:
            duration (int): User-provided command line argument for time duration of recording.
        """
        if duration > 0:
            params = self.audio_capture_params
            min_duration = int(np.ceil(params.min_samples / params.sampling_freq))
            if duration < min_duration:
                print(f"Minimum duration must be {min_duration} seconds of audio")
                print("Setting minimum recording duration...")
                duration = min_duration

            print(f"Recording duration is {duration} seconds")
            # Duration expressed in samples, then split into whole blocks.
            self.duration = params.sampling_freq * duration
            self.block_count, remainder_samples = divmod(self.duration, params.min_samples)

            # Count a trailing partial block only when it holds more than half a
            # second of audio.
            if remainder_samples > 0.5 * params.sampling_freq:
                self.block_count += 1
        else:
            self.duration = False  # Record forever

    def countdown(self, delay=3):
        """Prints a countdown before recording starts, sleeping one second per step.

        Args:
            delay: Number of seconds to count down from (default 3).
        """
        print("Beginning recording in...")
        for i in range(delay, 0, -1):
            print(f"{i}...")
            time.sleep(1)

    def update(self):
        """If a duration has been set, increments a counter to update the number of blocks of audio
        data left to be collected. The stream is deactivated upon reaching the maximum block count
        determined by the duration.
        """
        if self.duration:
            self.current_block += 1
            if self.current_block == self.block_count:
                self.is_active = False

    def capture_data(self):
        """Gets the next window of audio data by retrieving the newest data from a queue and
        shifting the position of the data in the collection. Overlap values of less than
        `min_samples` are supported.

        Blocks until a block of data is available on the queue.

        Returns:
            A copy of the first ``min_samples`` samples of the working buffer, cast to
            the configured dtype.

        Raises:
            ValueError: If a later window is requested with an overlap that is not less
                than ``min_samples``.
        """
        params = self.audio_capture_params
        min_samples = params.min_samples
        overlap = params.overlap

        # NOTE(review): `collection` is 1-D; blocks queued for non-mono capture are not
        # flattened by `callback` and presumably will not fit here - confirm with caller.
        new_data = self.queue.get()

        if self.is_first_window or overlap == 0:
            self.collection[:min_samples] = new_data[:]
            # Nothing else in this class clears the flag; without this the overlap
            # branch below is only reachable if the caller resets it manually.
            self.is_first_window = False

        elif overlap < min_samples:
            # Carry the tail of the previous window to the front of the buffer...
            self.collection[:overlap] = self.collection[(min_samples - overlap):min_samples]
            # ...and place the new block directly after it.
            self.collection[overlap:(overlap + min_samples)] = new_data[:]
        else:
            raise ValueError(
                "Capture Error: Overlap must be less than {}".format(min_samples))

        audio_data = self.collection[:min_samples]
        return np.asarray(audio_data).astype(params.dtype)

    def callback(self, data, frames, time, status):
        """Places audio data from active stream into a queue for processing.
        Update counter if recording duration is finite.

        Args:
            data: Block of audio samples; copied before being queued.
            frames: Unused.
            time: Unused. Note that inside this method the name shadows the
                module-level ``time`` import.
            status: Unused.
        """
        # update() does nothing when no finite duration has been set.
        self.update()

        if self.audio_capture_params.mono:
            audio_data = data.copy().flatten()
        else:
            audio_data = data.copy()

        # NOTE(review): the queue holds at most 2 blocks, so this put blocks while it
        # is full - confirm that is acceptable for the thread invoking the callback.
        self.queue.put(audio_data)