From 225178f6fb439f4e7c92a6189571c49cd303d60b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Thu, 9 Jan 2025 21:14:06 +0100 Subject: [PATCH] feat: functional multithreading transcription --- src/rai/rai/agents/voice_agent.py | 32 +++++++++++++------ .../communication/sound_device_connector.py | 9 ++---- src/rai_asr/rai_asr/models/local_whisper.py | 31 +++++++++++++++--- .../test_sound_device_connector.py | 1 - 4 files changed, 53 insertions(+), 20 deletions(-) diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py index 27caf5f1..d0508321 100644 --- a/src/rai/rai/agents/voice_agent.py +++ b/src/rai/rai/agents/voice_agent.py @@ -13,9 +13,10 @@ # limitations under the License. +import logging import time from threading import Event, Lock, Thread -from typing import Any, List, TypedDict +from typing import Any, List, Optional, TypedDict from uuid import uuid4 import numpy as np @@ -40,7 +41,12 @@ def __init__( transcription_model: BaseTranscriptionModel, vad: BaseVoiceDetectionModel, grace_period: float = 1.0, + logger: Optional[logging.Logger] = None, ): + if logger is None: + self.logger = logging.getLogger(__name__) + else: + self.logger = logger microphone = StreamingAudioInputDevice() microphone.configure_device( target=str(microphone_device_id), config=microphone_config @@ -87,16 +93,20 @@ def run(self): ) def stop(self): + self.logger.info("Stopping voice agent") self.running = False self.connectors["microphone"].terminate_action(self.listener_handle) - to_finish = list(self.transcription_threads.keys()) - while len(to_finish) > 0: + to_finish = len(list(self.transcription_threads.keys())) + while to_finish > 0: for thread_id in self.transcription_threads: if self.transcription_threads[thread_id]["event"].is_set(): self.transcription_threads[thread_id]["thread"].join() - to_finish.remove(thread_id) + to_finish -= 1 else: - print(f"Waiting for transcription of {thread_id} to finish") + self.logger.info( + f"Waiting for transcription of {thread_id} to finish..." + ) + self.logger.info("Voice agent stopped") def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): sample_time = time.time() @@ -112,7 +122,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): should_record = self.should_record(indata, output_parameters) if should_record: - print("Start recording") + self.logger.info("starting recording...") self.recording_started = True thread_id = str(uuid4())[0:8] transcription_thread = Thread( @@ -129,13 +139,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]): } if voice_detected: + self.logger.debug("Voice detected... resetting grace period") self.grace_period_start = sample_time if ( self.recording_started and sample_time - self.grace_period_start > self.grace_period ): - print("Stop recording") + self.logger.info("Grace period ended... stopping recording") self.recording_started = False self.grace_period_start = 0 with self.sample_buffer_lock: @@ -148,12 +159,12 @@ def should_record( ) -> bool: for model in self.should_record_pipeline: detected, output = model.detected(audio_data, input_parameters) - print(f"Detected: {detected}: {output}") if detected: return True return False def transcription_thread(self, identifier: str): + self.logger.info(f"transcription thread {identifier} started") with self.transcription_lock: while self.active_thread == identifier: with self.sample_buffer_lock: @@ -171,7 +182,10 @@ def transcription_thread(self, identifier: str): audio_data = np.concatenate(audio_data) self.transcription_model.transcribe(audio_data) del self.buffer_reminders[identifier] + # self.transcription_model.save_wav(f"{identifier}.wav") transcription = self.transcription_model.consume_transcription() self.transcription_threads[identifier]["transcription"] = transcription self.transcription_threads[identifier]["event"].set() - # TODO: sending the transcription + # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged + self.logger.info(f"transcription thread {identifier} finished") + print(transcription) diff --git a/src/rai/rai/communication/sound_device_connector.py b/src/rai/rai/communication/sound_device_connector.py index d209ecd8..323bfbbc 100644 --- a/src/rai/rai/communication/sound_device_connector.py +++ b/src/rai/rai/communication/sound_device_connector.py @@ -30,7 +30,6 @@ def __init__(self, msg: str): class AudioInputDeviceConfig(TypedDict): block_size: int consumer_sampling_rate: int - target_sampling_rate: int dtype: str device_number: Optional[int] @@ -44,7 +43,6 @@ class ConfiguredAudioInputDevice: sample_rate (int): Device sample rate consumer_sampling_rate (int): The sampling rate of the consumer window_size_samples (int): The size of the window in samples - target_sampling_rate (int): The target sampling rate dtype (str): The data type of the audio samples """ @@ -58,7 +56,6 @@ def __init__(self, config: AudioInputDeviceConfig): self.window_size_samples = int( config["block_size"] * self.sample_rate / config["consumer_sampling_rate"] ) - self.target_sampling_rate = int(config["target_sampling_rate"]) self.dtype = config["dtype"] @@ -108,9 +105,9 @@ def start_action( def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags): indata = indata.flatten() - sample_time_length = len(indata) / target_device.target_sampling_rate - if target_device.sample_rate != target_device.target_sampling_rate: - indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate)) # type: ignore + sample_time_length = len(indata) / target_device.sample_rate + if target_device.sample_rate != target_device.consumer_sampling_rate: + indata = resample(indata, int(sample_time_length * target_device.consumer_sampling_rate)) # type: ignore flag_dict = { "input_overflow": status.input_overflow, "input_underflow": status.input_underflow, diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index 84681c0e..77756c3c 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -26,20 +26,43 @@ class LocalWhisper(BaseTranscriptionModel): def __init__(self, model_name: str, sample_rate: int, language: str = "en"): super().__init__(model_name, sample_rate, language) if torch.cuda.is_available(): - print("Using CUDA") self.whisper = whisper.load_model(self.model_name, device="cuda") else: self.whisper = whisper.load_model(self.model_name) - self.samples = None + # TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging + # self.samples = None + + def consume_transcription(self) -> str: + ret = super().consume_transcription() + # self.samples = None + return ret + + # def save_wav(self, output_filename: str): + # assert self.samples is not None, "No samples to save" + # combined_samples = self.samples + # if combined_samples.dtype.kind == "f": + # combined_samples = np.clip(combined_samples, -1.0, 1.0) + # combined_samples = (combined_samples * 32767).astype(np.int16) + # elif combined_samples.dtype != np.int16: + # combined_samples = combined_samples.astype(np.int16) + + # with wave.open(output_filename, "wb") as wav_file: + # n_channels = 1 + # sampwidth = 2 + # wav_file.setnchannels(n_channels) + # wav_file.setsampwidth(sampwidth) + # wav_file.setframerate(self.sample_rate) + # wav_file.writeframes(combined_samples.tobytes()) def transcribe(self, data: NDArray[np.int16]): + # self.samples = ( + # np.concatenate((self.samples, data)) if self.samples is not None else data + # ) normalized_data = data.astype(np.float32) / 32768.0 - print("Starting transcription") result = whisper.transcribe( self.whisper, normalized_data ) # TODO: handling of additional transcribe arguments (perhaps in model init) - print("Finished transcription") transcription = result["text"] transcription = cast(str, transcription) self.latest_transcription += transcription diff --git a/tests/communication/test_sound_device_connector.py b/tests/communication/test_sound_device_connector.py index cb6d7210..f36f4815 100644 --- a/tests/communication/test_sound_device_connector.py +++ b/tests/communication/test_sound_device_connector.py @@ -31,7 +31,6 @@ def device_config(): return { "block_size": 1024, "consumer_sampling_rate": 44100, - "target_sampling_rate": 16000, "dtype": "float32", }