From 5ec3ec30545313c6e1b1b3576f9dc01f234649c7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Mon, 6 Jan 2025 22:51:15 +0100
Subject: [PATCH 01/16] feat: add configurable voice agent basic version

---
 src/rai/rai/agents/__init__.py         |  2 +
 src/rai/rai/agents/base.py             | 36 ++++++++++
 src/rai/rai/agents/voice_agent.py      | 97 ++++++++++++++++++++++++++
 src/rai/rai/communication/__init__.py  |  7 +-
 src/rai_asr/rai_asr/models/__init__.py | 17 +++++
 src/rai_asr/rai_asr/models/base.py     | 28 ++++++++
 6 files changed, 186 insertions(+), 1 deletion(-)
 create mode 100644 src/rai/rai/agents/base.py
 create mode 100644 src/rai/rai/agents/voice_agent.py
 create mode 100644 src/rai_asr/rai_asr/models/__init__.py
 create mode 100644 src/rai_asr/rai_asr/models/base.py

diff --git a/src/rai/rai/agents/__init__.py b/src/rai/rai/agents/__init__.py
index dc101282..2b7d4461 100644
--- a/src/rai/rai/agents/__init__.py
+++ b/src/rai/rai/agents/__init__.py
@@ -15,9 +15,11 @@
 from rai.agents.conversational_agent import create_conversational_agent
 from rai.agents.state_based import create_state_based_agent
 from rai.agents.tool_runner import ToolRunner
+from rai.agents.voice_agent import VoiceRecognitionAgent

 __all__ = [
     "ToolRunner",
     "create_conversational_agent",
     "create_state_based_agent",
+    "VoiceRecognitionAgent",
 ]

diff --git a/src/rai/rai/agents/base.py b/src/rai/rai/agents/base.py
new file mode 100644
index 00000000..285691c6
--- /dev/null
+++ b/src/rai/rai/agents/base.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import Optional
+
+from rai.communication import BaseConnector
+
+
+class BaseAgent(ABC):
+    def __init__(
+        self, connectors: Optional[dict[str, BaseConnector]] = None, *args, **kwargs
+    ):
+        if connectors is None:
+            connectors = {}
+        self.connectors: dict[str, BaseConnector] = connectors
+
+    @abstractmethod
+    def setup(self, *args, **kwargs):
+        pass
+
+    @abstractmethod
+    def run(self, *args, **kwargs):
+        pass

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
new file mode 100644
index 00000000..d0f3841e
--- /dev/null
+++ b/src/rai/rai/agents/voice_agent.py
@@ -0,0 +1,97 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from threading import Lock, Thread
+from typing import Any, List, Tuple
+
+import numpy as np
+from numpy.typing import NDArray
+
+from rai.agents.base import BaseAgent
+from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
+from rai_asr.models.base import BaseVoiceDetectionModel
+
+
+class VoiceRecognitionAgent(BaseAgent):
+    def __init__(self):
+        super().__init__(connectors={"microphone": StreamingAudioInputDevice()})
+        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
+        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
+        self.transcription_lock = Lock()
+        self.shared_samples = []
+        self.recording_started = False
+        self.ran_setup = False
+
+    def __call__(self):
+        self.run()
+
+    def setup(
+        self, microphone_device_id: int, microphone_config: AudioInputDeviceConfig
+    ):
+        assert isinstance(self.connectors["microphone"], StreamingAudioInputDevice)
+        self.microphone_device_id = str(microphone_device_id)
+        self.connectors["microphone"].configure_device(
+            target=self.microphone_device_id, config=microphone_config
+        )
+        self.ran_setup = True
+
+    def run(self):
+        self.listener_handle = self.connectors["microphone"].start_action(
+            self.microphone_device_id, self.on_new_sample
+        )
+        self.transcription_thread = Thread(target=self._transcription_function)
+        self.transcription_thread.start()
+
+    def stop(self):
+        self.connectors["microphone"].terminate_action(self.listener_handle)
+        self.transcription_thread.join()
+
+    def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
+        should_stop, should_cancel = self.should_stop_recording(indata)
+        print(indata)
+        if should_cancel:
+            self.cancel_task()
+        if (self.recording_started and not should_stop) or (
+            self.should_start_recording(indata)
+        ):
+            with self.transcription_lock:
+                self.shared_samples.extend(indata)
+
+    def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
+        output_parameters = {}
+        for model in self.should_record_pipeline:
+            should_listen, output_parameters = model.detected(
+                audio_data, output_parameters
+            )
+            if not should_listen:
+                return False
+        return True
+
+    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> Tuple[bool, bool]:
+        output_parameters = {}
+        for model in self.should_stop_pipeline:
+            should_listen, output_parameters = model.detected(
+                audio_data, output_parameters
+            )
+            # TODO: Add handling output parameters for checking if should cancel
+            if should_listen:
+                return False, False
+        return True, False
+
+    def _transcription_function(self):
+        with self.transcription_lock:
+            samples = np.array(self.shared_samples)
+            print(samples)
+            self.shared_samples = []

diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/communication/__init__.py
index 04c1fc4f..c49ac9e4 100644
--- a/src/rai/rai/communication/__init__.py
+++ b/src/rai/rai/communication/__init__.py
@@ -15,7 +15,11 @@
 from .ari_connector import ARIConnector, ARIMessage
 from .base_connector import BaseConnector, BaseMessage
 from .hri_connector import HRIConnector, HRIMessage, HRIPayload
-from .sound_device_connector import SoundDeviceError, StreamingAudioInputDevice
+from .sound_device_connector import (
+    AudioInputDeviceConfig,
+    SoundDeviceError,
+    StreamingAudioInputDevice,
+)

 __all__ = [
     "ARIConnector",
@@ -27,4 +31,5 @@
     "HRIPayload",
     "StreamingAudioInputDevice",
     "SoundDeviceError",
+    "AudioInputDeviceConfig",
 ]

diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py
new file mode 100644
index 00000000..64be0f21
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .base import BaseVoiceDetectionModel
+
+__all__ = ["BaseVoiceDetectionModel"]

diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
new file mode 100644
index 00000000..106120c1
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -0,0 +1,28 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from abc import ABC, abstractmethod
+from typing import Any, Tuple
+
+from numpy._typing import NDArray
+
+
+class BaseVoiceDetectionModel(ABC):
+
+    @abstractmethod
+    def detected(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> Tuple[bool, dict[str, Any]]:
+        pass
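Editor's note on the contract this patch introduces: every entry in `should_record_pipeline` / `should_stop_pipeline` implements `BaseVoiceDetectionModel.detected`, which takes the current audio chunk plus the previous model's output parameters and returns a decision together with updated parameters. A minimal sketch of a conforming model follows; the `EnergyGate` class and its threshold are illustrative and not part of this series:

```python
from typing import Any, Tuple

import numpy as np
from numpy.typing import NDArray

from rai_asr.models.base import BaseVoiceDetectionModel


class EnergyGate(BaseVoiceDetectionModel):
    """Hypothetical detector: flags chunks whose mean absolute amplitude exceeds a threshold."""

    def __init__(self, threshold: float = 500.0):
        self.threshold = threshold

    def detected(
        self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
        # Crude energy measure over the int16 chunk; the threshold is arbitrary.
        energy = float(np.abs(audio_data).mean())
        ret = input_parameters.copy()
        ret.update({"energy_gate": {"energy": energy}})
        return energy > self.threshold, ret
```

Because each model receives the previous model's output parameters, a pipeline can accumulate per-model diagnostics as it runs.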
From 47044e70a0a8765a60b7a46a5faabf45f981fd03 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 7 Jan 2025 12:54:41 +0100
Subject: [PATCH 02/16] feat: add silero vad model

---
 .pre-commit-config.yaml                  |  2 +-
 src/rai_asr/rai_asr/models/silero_vad.py | 61 ++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 1 deletion(-)
 create mode 100644 src/rai_asr/rai_asr/models/silero_vad.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c312f93e..2caf3cb7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -44,4 +44,4 @@ repos:
     rev: 7.1.0
     hooks:
       - id: flake8
-        args: ["--ignore=E501,E731,W503,W504"]
+        args: ["--ignore=E501,E731,W503,W504,E203"]

diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
new file mode 100644
index 00000000..1f6d0bf0
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -0,0 +1,61 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Tuple
+
+import numpy as np
+import torch
+from numpy.typing import NDArray
+
+from rai_asr.models import BaseVoiceDetectionModel
+
+
+class SileroVAD(BaseVoiceDetectionModel):
+    def __init__(self, sampling_rate=16000, threshold=0.5):
+        super(SileroVAD, self).__init__()
+        self.model_name = "silero_vad"
+        self.model, _ = torch.hub.load(
+            repo_or_dir="snakers4/silero-vad",
+            model=self.model_name,
+        )  # type: ignore
+        # NOTE: See silero vad implementation: https://github.com/snakers4/silero-vad/blob/9060f664f20eabb66328e4002a41479ff288f14c/src/silero_vad/utils_vad.py#L61
+        if sampling_rate == 16000:
+            self.sampling_rate = 16000
+            self.window_size = 512
+        elif sampling_rate == 8000:
+            self.sampling_rate = 8000
+            self.window_size = 256
+        else:
+            raise ValueError(
+                "Only 8000 and 16000 sampling rates are supported"
+            )  # TODO: consider if this should be a ValueError or something else
+        self.threshold = threshold
+
+    def int2float(self, sound: NDArray[np.int16]):
+        converted_sound = sound.astype("float32")
+        converted_sound *= 1 / 32768
+        converted_sound = converted_sound.squeeze()
+        return converted_sound
+
+    def detect(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> Tuple[bool, dict[str, Any]]:
+        vad_confidence = self.model(
+            torch.tensor(self.int2float(audio_data[-self.window_size :])),
+            self.sampling_rate,
+        ).item()
+        ret = input_parameters.copy()
+        ret.update({self.model_name: {"vad_confidence": vad_confidence}})
+
+        return vad_confidence > self.threshold, ret
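Editor's note: a sketch of driving the new `SileroVAD` in isolation. At this point in the series the method is still named `detect` (it is renamed to `detected` in the next commit); the int16 buffer below is synthetic stand-in data, and the first call downloads the model via `torch.hub`:

```python
import numpy as np

from rai_asr.models.silero_vad import SileroVAD

vad = SileroVAD(sampling_rate=16000, threshold=0.5)
chunk = np.zeros(512, dtype=np.int16)  # one 512-sample window at 16 kHz
is_voice, params = vad.detect(chunk, {})
print(is_voice, params["silero_vad"]["vad_confidence"])
```

Note that only the trailing `window_size` samples of each chunk are scored, matching the window sizes Silero supports for 8 kHz and 16 kHz input.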
From b67b992a4578ce27d52aaf1447f5aacce9c4cc62 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 7 Jan 2025 17:11:08 +0100
Subject: [PATCH 03/16] feat: add open wake word model

---
 src/rai_asr/rai_asr/models/__init__.py       |  6 ++-
 src/rai_asr/rai_asr/models/open_wake_word.py | 48 ++++++++++++++++++++
 src/rai_asr/rai_asr/models/silero_vad.py     |  3 +-
 3 files changed, 54 insertions(+), 3 deletions(-)
 create mode 100644 src/rai_asr/rai_asr/models/open_wake_word.py

diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py
index 64be0f21..a5b09f3c 100644
--- a/src/rai_asr/rai_asr/models/__init__.py
+++ b/src/rai_asr/rai_asr/models/__init__.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .base import BaseVoiceDetectionModel
+from rai_asr.models.base import BaseVoiceDetectionModel
+from rai_asr.models.open_wake_word import OpenWakeWord
+from rai_asr.models.silero_vad import SileroVAD

-__all__ = ["BaseVoiceDetectionModel"]
+__all__ = ["BaseVoiceDetectionModel", "SileroVAD", "OpenWakeWord"]

diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py
new file mode 100644
index 00000000..3daadbcf
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/open_wake_word.py
@@ -0,0 +1,48 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Tuple
+
+from numpy.typing import NDArray
+from openwakeword.model import Model as OWWModel
+from openwakeword.utils import download_models
+
+from rai_asr.models import BaseVoiceDetectionModel
+
+
+class OpenWakeWord(BaseVoiceDetectionModel):
+    def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
+        super(OpenWakeWord, self).__init__()
+        self.model_name = "open_wake_word"
+        download_models()
+        self.model = OWWModel(
+            wakeword_models=[
+                wake_word_model_path,
+            ],
+            inference_framework="onnx",
+        )
+        self.threshold = threshold
+
+    def detected(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> Tuple[bool, dict[str, Any]]:
+        print(len(audio_data))
+        predictions = self.model.predict(audio_data)
+        ret = input_parameters.copy()
+        ret.update({self.model_name: {"predictions": predictions}})
+        for key, value in predictions.items():
+            if value > self.threshold:
+                self.model.reset()
+                return True, ret
+        return False, ret

diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
index 1f6d0bf0..91211ef9 100644
--- a/src/rai_asr/rai_asr/models/silero_vad.py
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -48,7 +48,7 @@ def int2float(self, sound: NDArray[np.int16]):
         converted_sound = converted_sound.squeeze()
         return converted_sound

-    def detect(
+    def detected(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
    ) -> Tuple[bool, dict[str, Any]]:
         vad_confidence = self.model(
@@ -57,5 +57,6 @@ def detect(
         ).item()
         ret = input_parameters.copy()
         ret.update({self.model_name: {"vad_confidence": vad_confidence}})
+        self.model.reset_states()  # NOTE: see streaming example at the bottom https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies

         return vad_confidence > self.threshold, ret
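Editor's note: with the rename to `detected`, the VAD and the wake-word model now share one interface and can be chained, passing diagnostics forward. A sketch under stated assumptions — the `.onnx` path is a placeholder, and the 1280-sample frame (~80 ms at 16 kHz) is a typical openWakeWord frame size, not something this patch mandates:

```python
import numpy as np

from rai_asr.models import OpenWakeWord, SileroVAD

vad = SileroVAD(sampling_rate=16000, threshold=0.5)
oww = OpenWakeWord(wake_word_model_path="path/to/wake_word.onnx", threshold=0.5)

chunk = np.zeros(1280, dtype=np.int16)  # synthetic audio frame
voice_detected, params = vad.detected(chunk, {})
if voice_detected:
    # Only run the (heavier) wake-word model when the VAD fires.
    woke, params = oww.detected(chunk, params)
```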
From 925487bf61519f837467c36c9f7d937e30571065 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 7 Jan 2025 17:32:11 +0100
Subject: [PATCH 04/16] refactor: move transcription models for consistency

---
 src/rai_asr/rai_asr/asr_clients.py            |  2 +
 src/rai_asr/rai_asr/models/__init__.py        | 13 ++++-
 src/rai_asr/rai_asr/models/base.py            | 15 ++++++
 src/rai_asr/rai_asr/models/local_whisper.py   | 32 +++++++++++
 src/rai_asr/rai_asr/models/open_ai_whisper.py | 47 +++++++++++++++++++
 5 files changed, 107 insertions(+), 2 deletions(-)
 create mode 100644 src/rai_asr/rai_asr/models/local_whisper.py
 create mode 100644 src/rai_asr/rai_asr/models/open_ai_whisper.py

diff --git a/src/rai_asr/rai_asr/asr_clients.py b/src/rai_asr/rai_asr/asr_clients.py
index df538509..e08d0afd 100644
--- a/src/rai_asr/rai_asr/asr_clients.py
+++ b/src/rai_asr/rai_asr/asr_clients.py
@@ -24,6 +24,8 @@
 from scipy.io import wavfile
 from whisper.transcribe import transcribe

+# WARN: This file is going to be removed in favour of rai_asr.models
+

 class ASRModel:
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):

diff --git a/src/rai_asr/rai_asr/models/__init__.py b/src/rai_asr/rai_asr/models/__init__.py
index a5b09f3c..1d1a7e9d 100644
--- a/src/rai_asr/rai_asr/models/__init__.py
+++ b/src/rai_asr/rai_asr/models/__init__.py
@@ -12,8 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from rai_asr.models.base import BaseVoiceDetectionModel
+from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
+from rai_asr.models.local_whisper import LocalWhisper
+from rai_asr.models.open_ai_whisper import OpenAIWhisper
 from rai_asr.models.open_wake_word import OpenWakeWord
 from rai_asr.models.silero_vad import SileroVAD

-__all__ = ["BaseVoiceDetectionModel", "SileroVAD", "OpenWakeWord"]
+__all__ = [
+    "BaseVoiceDetectionModel",
+    "SileroVAD",
+    "OpenWakeWord",
+    "BaseTranscriptionModel",
+    "LocalWhisper",
+    "OpenAIWhisper",
+]

diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index 106120c1..20ec17c5 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -16,6 +16,7 @@
 from abc import ABC, abstractmethod
 from typing import Any, Tuple

+import numpy as np
 from numpy._typing import NDArray


@@ -26,3 +27,17 @@ def detected(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
         pass
+
+
+class BaseTranscriptionModel(ABC):
+    def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
+        self.model_name = model_name
+        self.sample_rate = sample_rate
+        self.language = language
+
+    @abstractmethod
+    def transcribe(self, data: NDArray[np.int16]) -> str:
+        pass
+
+    def __call__(self, data: NDArray[np.int16]) -> str:
+        return self.transcribe(data)

diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py
new file mode 100644
index 00000000..0a0b35b9
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/local_whisper.py
@@ -0,0 +1,32 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import whisper
+from numpy._typing import NDArray
+
+from rai_asr.models.base import BaseTranscriptionModel
+
+
+class LocalWhisper(BaseTranscriptionModel):
+    def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
+        super().__init__(model_name, sample_rate, language)
+        self.whisper = whisper.load_model(self.model_name)
+
+    def transcribe(self, data: NDArray[np.int16]) -> str:
+        result = whisper.transcribe(self.whisper, data.astype(np.float32) / 32768.0)
+        transcription = result["text"]
+        # NOTE: this is only for type enforcement, doesn't need to work at runtime
+        assert isinstance(transcription, str)
+        return transcription

diff --git a/src/rai_asr/rai_asr/models/open_ai_whisper.py b/src/rai_asr/rai_asr/models/open_ai_whisper.py
new file mode 100644
index 00000000..2706c786
--- /dev/null
+++ b/src/rai_asr/rai_asr/models/open_ai_whisper.py
@@ -0,0 +1,47 @@
+# Copyright (C) 2024 Robotec.AI
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import io
+import os
+from functools import partial
+
+import numpy as np
+from numpy.typing import NDArray
+from openai import OpenAI
+from scipy.io import wavfile
+
+from rai_asr.models.base import BaseTranscriptionModel
+
+
+class OpenAIWhisper(BaseTranscriptionModel):
+    def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
+        super().__init__(model_name, sample_rate, language)
+        api_key = os.getenv("OPENAI_API_KEY")
+        if api_key is None:
+            raise ValueError("OPENAI_API_KEY environment variable is not set.")
+        self.api_key = api_key
+        self.openai_client = OpenAI()
+        self.model = partial(
+            self.openai_client.audio.transcriptions.create,
+            model=self.model_name,
+        )
+
+    def transcribe(self, data: NDArray[np.int16]) -> str:
+        with io.BytesIO() as temp_wav_buffer:
+            wavfile.write(temp_wav_buffer, self.sample_rate, data)
+            temp_wav_buffer.seek(0)
+            temp_wav_buffer.name = "temp.wav"
+            response = self.model(file=temp_wav_buffer, language=self.language)
+        transcription = response.text
+        return transcription
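Editor's note: after this refactor a transcription model is a plain callable — it receives a complete int16 recording and returns text, with `__call__` forwarding to `transcribe`. A usage sketch (the "tiny" model name follows whisper's conventions; the buffer is synthetic silence):

```python
import numpy as np

from rai_asr.models import LocalWhisper

model = LocalWhisper(model_name="tiny", sample_rate=16000)
recording = np.zeros(16000, dtype=np.int16)  # one second of silence as stand-in data
text = model(recording)  # __call__ forwards to transcribe()
```

`OpenAIWhisper` is a drop-in alternative with the same signature, but it requires `OPENAI_API_KEY` to be set and sends the audio as an in-memory WAV file.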
From 9689ea982ab9f9e49ce663247ab40d570d0473b9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Wed, 8 Jan 2025 13:49:48 +0100
Subject: [PATCH 05/16] fix: apply suggestions from preliminary review

---
 src/rai/rai/agents/voice_agent.py             | 59 +++++++++++++------
 src/rai_asr/rai_asr/models/base.py            |  7 ++-
 src/rai_asr/rai_asr/models/local_whisper.py   | 23 ++++++--
 src/rai_asr/rai_asr/models/open_ai_whisper.py | 13 +++-
 src/rai_asr/rai_asr/models/silero_vad.py      |  4 +-
 5 files changed, 76 insertions(+), 30 deletions(-)

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
index d0f3841e..2657bfdd 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai/rai/agents/voice_agent.py
@@ -13,15 +13,16 @@
 # limitations under the License.


+import time
 from threading import Lock, Thread
-from typing import Any, List, Tuple
+from typing import Any, List, cast

 import numpy as np
 from numpy.typing import NDArray

 from rai.agents.base import BaseAgent
 from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
-from rai_asr.models.base import BaseVoiceDetectionModel
+from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel


 class VoiceRecognitionAgent(BaseAgent):
@@ -38,16 +39,34 @@ def __call__(self):
         self.run()

     def setup(
-        self, microphone_device_id: int, microphone_config: AudioInputDeviceConfig
+        self,
+        microphone_device_id: int,  # TODO: Change to name based instead of id based identification
+        microphone_config: AudioInputDeviceConfig,
+        transcription_model: BaseTranscriptionModel,
     ):
-        assert isinstance(self.connectors["microphone"], StreamingAudioInputDevice)
+        self.connectors["microphone"] = cast(
+            StreamingAudioInputDevice, self.connectors["microphone"]
+        )
         self.microphone_device_id = str(microphone_device_id)
         self.connectors["microphone"].configure_device(
             target=self.microphone_device_id, config=microphone_config
         )
+        self.transcription_model = transcription_model
         self.ran_setup = True
+        self.running = False
+
+    def add_detection_model(
+        self, model: BaseVoiceDetectionModel, pipeline: str = "record"
+    ):
+        if pipeline == "record":
+            self.should_record_pipeline.append(model)
+        elif pipeline == "stop":
+            self.should_stop_pipeline.append(model)
+        else:
+            raise ValueError("Pipeline should be either 'record' or 'stop'")

     def run(self):
+        self.running = True
         self.listener_handle = self.connectors["microphone"].start_action(
             self.microphone_device_id, self.on_new_sample
         )
@@ -55,17 +74,15 @@ def run(self):
         self.transcription_thread.start()

     def stop(self):
+        self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
         self.transcription_thread.join()

     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
-        should_stop, should_cancel = self.should_stop_recording(indata)
-        print(indata)
-        if should_cancel:
-            self.cancel_task()
-        if (self.recording_started and not should_stop) or (
-            self.should_start_recording(indata)
-        ):
+        should_stop = self.should_stop_recording(indata)
+        if self.should_start_recording(indata):
+            self.recording_started = True
+        if self.recording_started and not should_stop:
             with self.transcription_lock:
                 self.shared_samples.extend(indata)

@@ -75,23 +92,27 @@ def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
             should_listen, output_parameters = model.detected(
                 audio_data, output_parameters
             )
+            print(should_listen, output_parameters)
             if not should_listen:
                 return False
         return True

-    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> Tuple[bool, bool]:
+    def should_stop_recording(self, audio_data: NDArray[np.int16]) -> bool:
         output_parameters = {}
         for model in self.should_stop_pipeline:
             should_listen, output_parameters = model.detected(
                 audio_data, output_parameters
             )
-            # TODO: Add handling output parameters for checking if should cancel
             if should_listen:
-                return False, False
-        return True, False
+                return True
+        return False

     def _transcription_function(self):
-        with self.transcription_lock:
-            samples = np.array(self.shared_samples)
-            print(samples)
-            self.shared_samples = []
+        while self.running:
+            time.sleep(0.1)
+            # critical section for samples
+            with self.transcription_lock:
+                samples = np.array(self.shared_samples)
+                self.shared_samples = []
+            # end critical section for samples
+
+            self.transcription_model.add_samples(samples)

diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index 20ec17c5..a7648b8a 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -36,8 +36,9 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         self.language = language

     @abstractmethod
-    def transcribe(self, data: NDArray[np.int16]) -> str:
+    def add_samples(self, data: NDArray[np.int16]):
         pass

-    def __call__(self, data: NDArray[np.int16]) -> str:
-        return self.transcribe(data)
+    @abstractmethod
+    def transcribe(self) -> str:
+        pass

diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py
index 0a0b35b9..13d86c8a 100644
--- a/src/rai_asr/rai_asr/models/local_whisper.py
+++ b/src/rai_asr/rai_asr/models/local_whisper.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import cast
+
 import numpy as np
 import whisper
 from numpy._typing import NDArray
@@ -24,9 +26,22 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
         self.whisper = whisper.load_model(self.model_name)

-    def transcribe(self, data: NDArray[np.int16]) -> str:
-        result = whisper.transcribe(self.whisper, data.astype(np.float32) / 32768.0)
+        self.samples = None
+
+    def add_samples(self, data: NDArray[np.int16]):
+        normalized_data = data.astype(np.float32) / 32768.0
+        self.samples = (
+            np.concatenate([self.samples, normalized_data])
+            if self.samples is not None
+            else data
+        )
+
+    def transcribe(self) -> str:
+        if self.samples is None:
+            raise ValueError("No samples to transcribe")
+        result = whisper.transcribe(
+            self.whisper, self.samples
+        )  # TODO: handling of additional transcribe arguments (perhaps in model init)
         transcription = result["text"]
-        # NOTE: this is only for type enforcement, doesn't need to work at runtime
-        assert isinstance(transcription, str)
+        transcription = cast(str, transcription)
         return transcription

diff --git a/src/rai_asr/rai_asr/models/open_ai_whisper.py b/src/rai_asr/rai_asr/models/open_ai_whisper.py
index 2706c786..a5734f20 100644
--- a/src/rai_asr/rai_asr/models/open_ai_whisper.py
+++ b/src/rai_asr/rai_asr/models/open_ai_whisper.py
@@ -36,10 +36,19 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
             self.openai_client.audio.transcriptions.create,
             model=self.model_name,
         )
+        self.samples = []

-    def transcribe(self, data: NDArray[np.int16]) -> str:
+    def add_samples(self, data: NDArray[np.int16]):
+        normalized_data = data.astype(np.float32) / 32768.0
+        self.samples = (
+            np.concatenate([self.samples, normalized_data])
+            if self.samples is not None
+            else data
+        )
+
+    def transcribe(self) -> str:
         with io.BytesIO() as temp_wav_buffer:
-            wavfile.write(temp_wav_buffer, self.sample_rate, data)
+            wavfile.write(temp_wav_buffer, self.sample_rate, self.samples)
             temp_wav_buffer.seek(0)
             temp_wav_buffer.name = "temp.wav"
             response = self.model(file=temp_wav_buffer, language=self.language)

diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
index 91211ef9..98415481 100644
--- a/src/rai_asr/rai_asr/models/silero_vad.py
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Tuple
+from typing import Any, Literal, Tuple

 import numpy as np
 import torch
@@ -22,7 +22,7 @@


 class SileroVAD(BaseVoiceDetectionModel):
-    def __init__(self, sampling_rate=16000, threshold=0.5):
+    def __init__(self, sampling_rate: Literal[8000, 16000] = 16000, threshold=0.5):
         super(SileroVAD, self).__init__()
         self.model_name = "silero_vad"
         self.model, _ = torch.hub.load(
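Editor's note: with the `setup()` / `add_detection_model()` API from this patch, wiring the agent looks roughly like the sketch below. Device id 0 and the config values are examples and depend on the host's audio hardware (note that `target_sampling_rate` still exists at this point; it is removed in patch 07):

```python
from rai.agents import VoiceRecognitionAgent
from rai.communication import AudioInputDeviceConfig
from rai_asr.models import LocalWhisper, SileroVAD

config = AudioInputDeviceConfig(
    block_size=1280,
    consumer_sampling_rate=16000,
    target_sampling_rate=16000,
    dtype="int16",
    device_number=None,
)

agent = VoiceRecognitionAgent()
agent.setup(
    microphone_device_id=0,
    microphone_config=config,
    transcription_model=LocalWhisper(model_name="tiny", sample_rate=16000),
)
agent.add_detection_model(SileroVAD(), pipeline="record")
agent.run()  # starts the listener and the background transcription loop
```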
From 5749111cae4846f6e07fa79b249ab0a5b9cae5ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Thu, 9 Jan 2025 15:38:08 +0100
Subject: [PATCH 06/16] feat: basic multithreading implementation of transcription

---
 src/rai/rai/agents/base.py                   |   4 -
 src/rai/rai/agents/voice_agent.py            | 171 +++++++++++++------
 src/rai_asr/rai_asr/models/base.py           |  11 +-
 src/rai_asr/rai_asr/models/local_whisper.py  |  24 ++-
 src/rai_asr/rai_asr/models/open_wake_word.py |   3 +-
 src/rai_asr/rai_asr/models/silero_vad.py     |   1 -
 6 files changed, 134 insertions(+), 80 deletions(-)

diff --git a/src/rai/rai/agents/base.py b/src/rai/rai/agents/base.py
index 285691c6..c2dd4fe5 100644
--- a/src/rai/rai/agents/base.py
+++ b/src/rai/rai/agents/base.py
@@ -27,10 +27,6 @@ def __init__(
             connectors = {}
         self.connectors: dict[str, BaseConnector] = connectors

-    @abstractmethod
-    def setup(self, *args, **kwargs):
-        pass
-
     @abstractmethod
     def run(self, *args, **kwargs):
         pass

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
index 2657bfdd..27caf5f1 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai/rai/agents/voice_agent.py
@@ -14,46 +14,61 @@


 import time
-from threading import Lock, Thread
-from typing import Any, List, cast
+from threading import Event, Lock, Thread
+from typing import Any, List, TypedDict
+from uuid import uuid4

 import numpy as np
 from numpy.typing import NDArray

 from rai.agents.base import BaseAgent
 from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
-from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
+from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel


-class VoiceRecognitionAgent(BaseAgent):
-    def __init__(self):
-        super().__init__(connectors={"microphone": StreamingAudioInputDevice()})
-        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
-        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
-        self.transcription_lock = Lock()
-        self.shared_samples = []
-        self.recording_started = False
-        self.ran_setup = False
+class ThreadData(TypedDict):
+    thread: Thread
+    event: Event
+    transcription: str

-    def __call__(self):
-        self.run()

-    def setup(
+class VoiceRecognitionAgent(BaseAgent):
+    def __init__(
         self,
         microphone_device_id: int,  # TODO: Change to name based instead of id based identification
         microphone_config: AudioInputDeviceConfig,
         transcription_model: BaseTranscriptionModel,
+        vad: BaseVoiceDetectionModel,
+        grace_period: float = 1.0,
     ):
-        self.connectors["microphone"] = cast(
-            StreamingAudioInputDevice, self.connectors["microphone"]
+        microphone = StreamingAudioInputDevice()
+        microphone.configure_device(
+            target=str(microphone_device_id), config=microphone_config
         )
+        super().__init__(connectors={"microphone": microphone})
         self.microphone_device_id = str(microphone_device_id)
+        self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
+        self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
+        self.transcription_model = transcription_model
+        self.transcription_lock = Lock()
+
+        self.vad: BaseVoiceDetectionModel = vad
+
+        self.grace_period = grace_period
+        self.grace_period_start = 0
+
+        self.recording_started = False
+        self.ran_setup = False
+
+        self.sample_buffer = []
+        self.sample_buffer_lock = Lock()
+        self.active_thread = ""
+        self.transcription_threads: dict[str, ThreadData] = {}
+        self.buffer_reminders: dict[str, list[NDArray]] = {}
+
+    def __call__(self):
+        self.run()

     def add_detection_model(
         self, model: BaseVoiceDetectionModel, pipeline: str = "record"
@@ -70,49 +85,93 @@ def run(self):
         self.listener_handle = self.connectors["microphone"].start_action(
             self.microphone_device_id, self.on_new_sample
         )
-        self.transcription_thread = Thread(target=self._transcription_function)
-        self.transcription_thread.start()

     def stop(self):
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
-        self.transcription_thread.join()
+        to_finish = list(self.transcription_threads.keys())
+        while len(to_finish) > 0:
+            for thread_id in self.transcription_threads:
+                if self.transcription_threads[thread_id]["event"].is_set():
+                    self.transcription_threads[thread_id]["thread"].join()
+                    to_finish.remove(thread_id)
+                else:
+                    print(f"Waiting for transcription of {thread_id} to finish")

     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
-        should_stop = self.should_stop_recording(indata)
-        if self.should_start_recording(indata):
+        sample_time = time.time()
+        with self.sample_buffer_lock:
+            self.sample_buffer.append(indata)
+            if not self.recording_started and len(self.sample_buffer) > 5:
+                self.sample_buffer = self.sample_buffer[-5:]
+
+        voice_detected, output_parameters = self.vad.detected(indata, {})
+        should_record = False
+        # TODO: second condition is temporary
+        if voice_detected and not self.recording_started:
+            should_record = self.should_record(indata, output_parameters)
+
+        if should_record:
+            print("Start recording")
             self.recording_started = True
-        if self.recording_started and not should_stop:
-            with self.transcription_lock:
-                self.shared_samples.extend(indata)
+            thread_id = str(uuid4())[0:8]
+            transcription_thread = Thread(
+                target=self.transcription_thread,
+                args=[thread_id],
+            )
+            transcription_finished = Event()
+            self.active_thread = thread_id
+            transcription_thread.start()
+            self.transcription_threads[thread_id] = {
+                "thread": transcription_thread,
+                "event": transcription_finished,
+                "transcription": "",
+            }
+
+        if voice_detected:
+            self.grace_period_start = sample_time

-    def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
-        output_parameters = {}
+        if (
+            self.recording_started
+            and sample_time - self.grace_period_start > self.grace_period
+        ):
+            print("Stop recording")
+            self.recording_started = False
+            self.grace_period_start = 0
+            with self.sample_buffer_lock:
+                self.buffer_reminders[self.active_thread] = self.sample_buffer
+                self.sample_buffer = []
+            self.active_thread = ""
+
+    def should_record(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> bool:
         for model in self.should_record_pipeline:
-            should_listen, output_parameters = model.detected(
-                audio_data, output_parameters
-            )
-            if should_listen:
+            detected, output = model.detected(audio_data, input_parameters)
+            print(f"Detected: {detected}: {output}")
+            if detected:
                 return True
         return False

-    def _transcription_function(self):
-        while self.running:
-            time.sleep(0.1)
-            # critical section for samples
-            with self.transcription_lock:
-                samples = np.array(self.shared_samples)
-                self.shared_samples = []
-            # end critical section for samples
-
-            self.transcription_model.add_samples(samples)
+    def transcription_thread(self, identifier: str):
+        with self.transcription_lock:
+            while self.active_thread == identifier:
+                with self.sample_buffer_lock:
+                    if len(self.sample_buffer) == 0:
+                        continue
+                    audio_data = self.sample_buffer.copy()
+                    self.sample_buffer = []
+                audio_data = np.concatenate(audio_data)
+                self.transcription_model.transcribe(audio_data)
+
+            # transcription of the remainder of the buffer
+            with self.sample_buffer_lock:
+                if identifier in self.buffer_reminders:
+                    audio_data = self.buffer_reminders[identifier]
+                    audio_data = np.concatenate(audio_data)
+                    self.transcription_model.transcribe(audio_data)
+                    del self.buffer_reminders[identifier]
+            transcription = self.transcription_model.consume_transcription()
+            self.transcription_threads[identifier]["transcription"] = transcription
+            self.transcription_threads[identifier]["event"].set()
+            # TODO: sending the transcription

diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index a7648b8a..aa02d25f 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -35,10 +35,13 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         self.sample_rate = sample_rate
         self.language = language

-    @abstractmethod
-    def add_samples(self, data: NDArray[np.int16]):
-        pass
+        self.latest_transcription = ""
+
+    def consume_transcription(self) -> str:
+        ret = self.latest_transcription
+        self.latest_transcription = ""
+        return ret

     @abstractmethod
-    def transcribe(self) -> str:
+    def transcribe(self, data: NDArray[np.int16]):
         pass

diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py
index 13d86c8a..84681c0e 100644
--- a/src/rai_asr/rai_asr/models/local_whisper.py
+++ b/src/rai_asr/rai_asr/models/local_whisper.py
@@ -15,6 +15,7 @@
 from typing import cast

 import numpy as np
+import torch
 import whisper
 from numpy._typing import NDArray

@@ -24,24 +25,21 @@ class LocalWhisper(BaseTranscriptionModel):
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
-        self.whisper = whisper.load_model(self.model_name)
+        if torch.cuda.is_available():
+            print("Using CUDA")
+            self.whisper = whisper.load_model(self.model_name, device="cuda")
+        else:
+            self.whisper = whisper.load_model(self.model_name)

-        self.samples = None
-
-    def add_samples(self, data: NDArray[np.int16]):
+    def transcribe(self, data: NDArray[np.int16]):
         normalized_data = data.astype(np.float32) / 32768.0
-        self.samples = (
-            np.concatenate([self.samples, normalized_data])
-            if self.samples is not None
-            else data
-        )
-
-    def transcribe(self) -> str:
-        if self.samples is None:
-            raise ValueError("No samples to transcribe")
+        print("Starting transcription")
         result = whisper.transcribe(
-            self.whisper, self.samples
+            self.whisper, normalized_data
         )  # TODO: handling of additional transcribe arguments (perhaps in model init)
+        print("Finished transcription")
         transcription = result["text"]
         transcription = cast(str, transcription)
-        return transcription
+        self.latest_transcription += transcription

diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py
index 3daadbcf..d68dbf8d 100644
--- a/src/rai_asr/rai_asr/models/open_wake_word.py
+++ b/src/rai_asr/rai_asr/models/open_wake_word.py
@@ -37,12 +37,11 @@ def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
     def detected(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
-        print(len(audio_data))
         predictions = self.model.predict(audio_data)
         ret = input_parameters.copy()
         ret.update({self.model_name: {"predictions": predictions}})
         for key, value in predictions.items():
             if value > self.threshold:
-                self.model.reset()
+                # self.model.reset()
                 return True, ret
         return False, ret

diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
index 98415481..34ad0135 100644
--- a/src/rai_asr/rai_asr/models/silero_vad.py
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -57,6 +57,5 @@ def detected(
         ).item()
         ret = input_parameters.copy()
         ret.update({self.model_name: {"vad_confidence": vad_confidence}})
-        self.model.reset_states()  # NOTE: see streaming example at the bottom https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies

         return vad_confidence > self.threshold, ret
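Editor's note: this patch removes `setup()` entirely — everything is passed to the constructor, the VAD runs on every chunk, and a per-recording worker thread handles transcription. A wiring sketch under the patch-06 signature (values illustrative; the wake-word path is a placeholder):

```python
from rai.agents import VoiceRecognitionAgent
from rai.communication import AudioInputDeviceConfig
from rai_asr.models import LocalWhisper, OpenWakeWord, SileroVAD

config: AudioInputDeviceConfig = {
    "block_size": 1280,
    "consumer_sampling_rate": 16000,
    "target_sampling_rate": 16000,
    "dtype": "int16",
    "device_number": None,
}

agent = VoiceRecognitionAgent(
    microphone_device_id=0,
    microphone_config=config,
    transcription_model=LocalWhisper(model_name="tiny", sample_rate=16000),
    vad=SileroVAD(sampling_rate=16000, threshold=0.5),
    grace_period=1.0,  # seconds of silence before a recording is closed
)
agent.add_detection_model(OpenWakeWord("path/to/wake_word.onnx"), pipeline="record")
agent.run()
```

The design choice worth noting: the VAD gates the (heavier) `should_record_pipeline`, and the grace period keeps a recording open across short pauses in speech.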
From 1e948a22b5191a513ca83f3882407d8d00928c6e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Thu, 9 Jan 2025 21:14:06 +0100
Subject: [PATCH 07/16] feat: functional multithreading transcription

---
 src/rai/rai/agents/voice_agent.py             | 32 +++++++++++++------
 .../communication/sound_device_connector.py   |  9 ++----
 src/rai_asr/rai_asr/models/local_whisper.py   | 31 +++++++++++++---
 .../test_sound_device_connector.py            |  1 -
 4 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
index 27caf5f1..d0508321 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai/rai/agents/voice_agent.py
@@ -13,9 +13,10 @@
 # limitations under the License.


+import logging
 import time
 from threading import Event, Lock, Thread
-from typing import Any, List, TypedDict
+from typing import Any, List, Optional, TypedDict
 from uuid import uuid4

 import numpy as np
@@ -40,7 +41,12 @@ def __init__(
         transcription_model: BaseTranscriptionModel,
         vad: BaseVoiceDetectionModel,
         grace_period: float = 1.0,
+        logger: Optional[logging.Logger] = None,
     ):
+        if logger is None:
+            self.logger = logging.getLogger(__name__)
+        else:
+            self.logger = logger
         microphone = StreamingAudioInputDevice()
         microphone.configure_device(
             target=str(microphone_device_id), config=microphone_config
@@ -87,16 +93,20 @@ def run(self):
         )

     def stop(self):
+        self.logger.info("Stopping voice agent")
         self.running = False
         self.connectors["microphone"].terminate_action(self.listener_handle)
-        to_finish = list(self.transcription_threads.keys())
-        while len(to_finish) > 0:
+        to_finish = len(list(self.transcription_threads.keys()))
+        while to_finish > 0:
             for thread_id in self.transcription_threads:
                 if self.transcription_threads[thread_id]["event"].is_set():
                     self.transcription_threads[thread_id]["thread"].join()
-                    to_finish.remove(thread_id)
+                    to_finish -= 1
                 else:
-                    print(f"Waiting for transcription of {thread_id} to finish")
+                    self.logger.info(
+                        f"Waiting for transcription of {thread_id} to finish..."
+                    )
+        self.logger.info("Voice agent stopped")

     def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
         sample_time = time.time()
@@ -112,7 +122,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             should_record = self.should_record(indata, output_parameters)

         if should_record:
-            print("Start recording")
+            self.logger.info("starting recording...")
             self.recording_started = True
             thread_id = str(uuid4())[0:8]
             transcription_thread = Thread(
@@ -129,13 +139,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
             }

         if voice_detected:
+            self.logger.debug("Voice detected... resetting grace period")
             self.grace_period_start = sample_time

         if (
             self.recording_started
             and sample_time - self.grace_period_start > self.grace_period
         ):
-            print("Stop recording")
+            self.logger.info("Grace period ended... stopping recording")
             self.recording_started = False
             self.grace_period_start = 0
             with self.sample_buffer_lock:
@@ -148,12 +159,12 @@ def should_record(
     ) -> bool:
         for model in self.should_record_pipeline:
             detected, output = model.detected(audio_data, input_parameters)
-            print(f"Detected: {detected}: {output}")
             if detected:
                 return True
         return False

     def transcription_thread(self, identifier: str):
+        self.logger.info(f"transcription thread {identifier} started")
         with self.transcription_lock:
             while self.active_thread == identifier:
                 with self.sample_buffer_lock:
@@ -171,7 +182,10 @@ def transcription_thread(self, identifier: str):
                     audio_data = np.concatenate(audio_data)
                     self.transcription_model.transcribe(audio_data)
                     del self.buffer_reminders[identifier]
+            # self.transcription_model.save_wav(f"{identifier}.wav")
             transcription = self.transcription_model.consume_transcription()
             self.transcription_threads[identifier]["transcription"] = transcription
             self.transcription_threads[identifier]["event"].set()
-            # TODO: sending the transcription
+            # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
+            self.logger.info(f"transcription thread {identifier} finished")
+            print(transcription)

diff --git a/src/rai/rai/communication/sound_device_connector.py b/src/rai/rai/communication/sound_device_connector.py
index 449c0890..40edad01 100644
--- a/src/rai/rai/communication/sound_device_connector.py
+++ b/src/rai/rai/communication/sound_device_connector.py
@@ -30,7 +30,6 @@ def __init__(self, msg: str):
 class AudioInputDeviceConfig(TypedDict):
     block_size: int
     consumer_sampling_rate: int
-    target_sampling_rate: int
     dtype: str
     device_number: Optional[int]

@@ -44,7 +43,6 @@ class ConfiguredAudioInputDevice:
         sample_rate (int): Device sample rate
         consumer_sampling_rate (int): The sampling rate of the consumer
         window_size_samples (int): The size of the window in samples
-        target_sampling_rate (int): The target sampling rate
         dtype (str): The data type of the audio samples
     """

@@ -58,7 +56,6 @@ def __init__(self, config: AudioInputDeviceConfig):
         self.window_size_samples = int(
             config["block_size"] * self.sample_rate / config["consumer_sampling_rate"]
         )
-        self.target_sampling_rate = int(config["target_sampling_rate"])
         self.dtype = config["dtype"]


@@ -132,9 +129,9 @@ def start_action(

         def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
             indata = indata.flatten()
-            sample_time_length = len(indata) / target_device.target_sampling_rate
-            if target_device.sample_rate != target_device.target_sampling_rate:
-                indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate))  # type: ignore
+            sample_time_length = len(indata) / target_device.sample_rate
+            if target_device.sample_rate != target_device.consumer_sampling_rate:
+                indata = resample(indata, int(sample_time_length * target_device.consumer_sampling_rate))  # type: ignore
             flag_dict = {
                 "input_overflow": status.input_overflow,
                 "input_underflow": status.input_underflow,

diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py
index 84681c0e..77756c3c 100644
--- a/src/rai_asr/rai_asr/models/local_whisper.py
+++ b/src/rai_asr/rai_asr/models/local_whisper.py
@@ -26,20 +26,43 @@ class LocalWhisper(BaseTranscriptionModel):
     def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
         super().__init__(model_name, sample_rate, language)
         if torch.cuda.is_available():
-            print("Using CUDA")
             self.whisper = whisper.load_model(self.model_name, device="cuda")
         else:
             self.whisper = whisper.load_model(self.model_name)

-        self.samples = None
+        # TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging
+        # self.samples = None
+
+    def consume_transcription(self) -> str:
+        ret = super().consume_transcription()
+        # self.samples = None
+        return ret
+
+    # def save_wav(self, output_filename: str):
+    #     assert self.samples is not None, "No samples to save"
+    #     combined_samples = self.samples
+    #     if combined_samples.dtype.kind == "f":
+    #         combined_samples = np.clip(combined_samples, -1.0, 1.0)
+    #         combined_samples = (combined_samples * 32767).astype(np.int16)
+    #     elif combined_samples.dtype != np.int16:
+    #         combined_samples = combined_samples.astype(np.int16)
+
+    #     with wave.open(output_filename, "wb") as wav_file:
+    #         n_channels = 1
+    #         sampwidth = 2
+    #         wav_file.setnchannels(n_channels)
+    #         wav_file.setsampwidth(sampwidth)
+    #         wav_file.setframerate(self.sample_rate)
+    #         wav_file.writeframes(combined_samples.tobytes())

     def transcribe(self, data: NDArray[np.int16]):
+        # self.samples = (
+        #     np.concatenate((self.samples, data)) if self.samples is not None else data
+        # )
         normalized_data = data.astype(np.float32) / 32768.0
-        print("Starting transcription")
         result = whisper.transcribe(
             self.whisper, normalized_data
         )  # TODO: handling of additional transcribe arguments (perhaps in model init)
-        print("Finished transcription")
         transcription = result["text"]
         transcription = cast(str, transcription)
         self.latest_transcription += transcription

diff --git a/tests/communication/test_sound_device_connector.py b/tests/communication/test_sound_device_connector.py
index 6eb2b0bd..b98ca255 100644
--- a/tests/communication/test_sound_device_connector.py
+++ b/tests/communication/test_sound_device_connector.py
@@ -31,7 +31,6 @@ def device_config():
     return {
         "block_size": 1024,
        "consumer_sampling_rate": 44100,
-        "target_sampling_rate": 16000,
         "dtype": "float32",
     }
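Editor's note: the corrected callback above resamples from the device's native rate to the consumer rate (the old code divided by the wrong rate). The arithmetic in isolation, with example rates:

```python
import numpy as np
from scipy.signal import resample

device_rate, consumer_rate = 44100, 16000
indata = np.zeros(4410, dtype=np.float32)  # 0.1 s of audio at the device rate

duration = len(indata) / device_rate       # seconds of audio in this block
if device_rate != consumer_rate:
    indata = resample(indata, int(duration * consumer_rate))

assert len(indata) == 1600                 # 0.1 s at the consumer rate
```

With `target_sampling_rate` gone, the consumer's requested rate is the single source of truth for resampling.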
From 7dc4b43f7d67fba6cd35c405bd58ef37764efafb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Tue, 21 Jan 2025 17:18:03 +0100
Subject: [PATCH 08/16] feat: integrate with ros2 connector

---
 src/rai/rai/agents/voice_agent.py            | 26 +++++++++++++-----
 src/rai/rai/communication/__init__.py        |  3 +++
 src/rai/rai/communication/ros2/connectors.py | 10 ++++++--
 3 files changed, 31 insertions(+), 8 deletions(-)

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
index d0508321..c3d17983 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai/rai/agents/voice_agent.py
@@ -23,7 +23,12 @@
 from numpy.typing import NDArray

 from rai.agents.base import BaseAgent
-from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
+from rai.communication import (
+    AudioInputDeviceConfig,
+    ROS2ARIConnector,
+    ROS2ARIMessage,
+    StreamingAudioInputDevice,
+)
 from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel


@@ -38,6 +43,7 @@ def __init__(
         self,
         microphone_device_id: int,  # TODO: Change to name based instead of id based identification
         microphone_config: AudioInputDeviceConfig,
+        ros2_name: str,
         transcription_model: BaseTranscriptionModel,
         vad: BaseVoiceDetectionModel,
         grace_period: float = 1.0,
@@ -51,7 +57,8 @@ def __init__(
         microphone.configure_device(
             target=str(microphone_device_id), config=microphone_config
         )
-        super().__init__(connectors={"microphone": microphone})
+        ros2_connector = ROS2ARIConnector(ros2_name)
+        super().__init__(connectors={"microphone": microphone, "ros2": ros2_connector})
         self.microphone_device_id = str(microphone_device_id)
         self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
         self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
@@ -89,7 +96,10 @@ def add_detection_model(
     def run(self):
         self.running = True
         self.listener_handle = self.connectors["microphone"].start_action(
-            self.microphone_device_id, self.on_new_sample
+            action_data=None,
+            target=self.microphone_device_id,
+            on_feedback=self.on_new_sample,
+            on_done=lambda: None,
         )

     def stop(self):
@@ -184,8 +194,12 @@ def transcription_thread(self, identifier: str):
                     del self.buffer_reminders[identifier]
             # self.transcription_model.save_wav(f"{identifier}.wav")
             transcription = self.transcription_model.consume_transcription()
+            print("Transcription: ", transcription)
+            self.connectors["ros2"].send_message(
+                ROS2ARIMessage(
+                    {"data": transcription}, {"msg_type": "std_msgs/msg/String"}
+                ),
+                "/from_human",
+            )
             self.transcription_threads[identifier]["transcription"] = transcription
             self.transcription_threads[identifier]["event"].set()
-            # TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
-            self.logger.info(f"transcription thread {identifier} finished")
-            print(transcription)

diff --git a/src/rai/rai/communication/__init__.py b/src/rai/rai/communication/__init__.py
index c49ac9e4..5134c2d9 100644
--- a/src/rai/rai/communication/__init__.py
+++ b/src/rai/rai/communication/__init__.py
@@ -15,6 +15,7 @@
 from .ari_connector import ARIConnector, ARIMessage
 from .base_connector import BaseConnector, BaseMessage
 from .hri_connector import HRIConnector, HRIMessage, HRIPayload
+from .ros2.connectors import ROS2ARIConnector, ROS2ARIMessage
 from .sound_device_connector import (
     AudioInputDeviceConfig,
     SoundDeviceError,
@@ -29,6 +30,8 @@
     "HRIConnector",
     "HRIMessage",
     "HRIPayload",
+    "ROS2ARIConnector",
+    "ROS2ARIMessage",
     "StreamingAudioInputDevice",
     "SoundDeviceError",
     "AudioInputDeviceConfig",

diff --git a/src/rai/rai/communication/ros2/connectors.py b/src/rai/rai/communication/ros2/connectors.py
index 9ee24c11..e51e4b47 100644
--- a/src/rai/rai/communication/ros2/connectors.py
+++ b/src/rai/rai/communication/ros2/connectors.py
@@ -14,7 +14,7 @@

 import threading
 import uuid
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, TypedDict

 from rclpy.executors import MultiThreadedExecutor
 from rclpy.node import Node
@@ -23,8 +23,14 @@
 from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI


+class ROS2ARIPayload(TypedDict):
+    data: Any
+
+
 class ROS2ARIMessage(ARIMessage):
-    def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
+    def __init__(
+        self, payload: ROS2ARIPayload, metadata: Optional[Dict[str, Any]] = None
+    ):
         super().__init__(payload, metadata)
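Editor's note: transcriptions are now published on the `/from_human` topic as `std_msgs/msg/String`, so any ROS 2 node can consume them. A plain rclpy listener is enough to observe the output (the node name below is arbitrary):

```python
import rclpy
from rclpy.node import Node
from std_msgs.msg import String


class FromHumanListener(Node):
    def __init__(self):
        super().__init__("from_human_listener")
        # Queue depth of 10 is a conventional default.
        self.create_subscription(String, "/from_human", self.on_msg, 10)

    def on_msg(self, msg: String):
        self.get_logger().info(f"transcription: {msg.data}")


rclpy.init()
rclpy.spin(FromHumanListener())
```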
stopping recording, starting transcription" + ) self.recording_started = False self.grace_period_start = 0 with self.sample_buffer_lock: - self.buffer_reminders[self.active_thread] = self.sample_buffer + self.transcription_buffers[self.active_thread] = self.sample_buffer self.sample_buffer = [] + self.transcription_threads[self.active_thread]["thread"].start() self.active_thread = "" def should_record( @@ -175,31 +186,46 @@ def should_record( def transcription_thread(self, identifier: str): self.logger.info(f"transcription thread {identifier} started") - with self.transcription_lock: - while self.active_thread == identifier: - with self.sample_buffer_lock: - if len(self.sample_buffer) == 0: - continue - audio_data = self.sample_buffer.copy() - self.sample_buffer = [] - audio_data = np.concatenate(audio_data) - self.transcription_model.transcribe(audio_data) - - # transciption of the reminder of the buffer - with self.sample_buffer_lock: - if identifier in self.buffer_reminders: - audio_data = self.buffer_reminders[identifier] - audio_data = np.concatenate(audio_data) - self.transcription_model.transcribe(audio_data) - del self.buffer_reminders[identifier] - # self.transcription_model.save_wav(f"{identifier}.wav") - transcription = self.transcription_model.consume_transcription() - print("Transcription: ", transcription) - self.connectors["ros2"].send_message( - ROS2ARIMessage( - {"data": transcription}, {"msg_type": "std_msgs/msg/String"} - ), - "/from_human", - ) - self.transcription_threads[identifier]["transcription"] = transcription - self.transcription_threads[identifier]["event"].set() + audio_data = np.concatenate(self.transcription_buffers[identifier]) + with self.transcription_lock: # this is only necessary for the local model... TODO: fix this somehow + transcription = self.transcription_model.transcribe(audio_data) + self.connectors["ros2"].send_message( + ROS2ARIMessage( + {"data": transcription}, {"msg_type": "std_msgs/msg/String"} + ), + "/from_human", + ) + self.transcription_threads[identifier]["transcription"] = transcription + self.transcription_threads[identifier]["event"].set() + + # with self.transcription_lock: + # while self.active_thread == identifier: + # with self.sample_buffer_lock: + # if len(self.sample_buffer) == 0: + # continue + # audio_data = self.sample_buffer.copy() + # self.sample_buffer = [] + # audio_data = np.concatenate(audio_data) + # with self.transcription_lock: + # self.transcription_model.transcribe(audio_data) + + # # transciption of the reminder of the buffer + # with self.sample_buffer_lock: + # if identifier in self.transcription_buffers: + # audio_data = self.transcription_buffers[identifier] + # audio_data = np.concatenate(audio_data) + # with self.transcription_lock: + # self.transcription_model.transcribe(audio_data) + # del self.transcription_buffers[identifier] + # # self.transcription_model.save_wav(f"{identifier}.wav") + # with self.transcription_lock: + # transcription = self.transcription_model.consume_transcription() + # self.logger.info(f"Transcription: {transcription}") + # self.connectors["ros2"].send_message( + # ROS2ARIMessage( + # {"data": transcription}, {"msg_type": "std_msgs/msg/String"} + # ), + # "/from_human", + # ) + # self.transcription_threads[identifier]["transcription"] = transcription + # self.transcription_threads[identifier]["event"].set() diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py index aa02d25f..d45a6176 100644 --- a/src/rai_asr/rai_asr/models/base.py +++ 
b/src/rai_asr/rai_asr/models/base.py @@ -37,11 +37,6 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): self.latest_transcription = "" - def consume_transcription(self) -> str: - ret = self.latest_transcription - self.latest_transcription = "" - return ret - @abstractmethod - def transcribe(self, data: NDArray[np.int16]): + def transcribe(self, data: NDArray[np.int16]) -> str: pass diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index 77756c3c..571f0233 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import cast import numpy as np @@ -30,14 +31,10 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): else: self.whisper = whisper.load_model(self.model_name) + self.logger = logging.getLogger(__name__) # TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging # self.samples = None - def consume_transcription(self) -> str: - ret = super().consume_transcription() - # self.samples = None - return ret - # def save_wav(self, output_filename: str): # assert self.samples is not None, "No samples to save" # combined_samples = self.samples @@ -55,14 +52,13 @@ def consume_transcription(self) -> str: # wav_file.setframerate(self.sample_rate) # wav_file.writeframes(combined_samples.tobytes()) - def transcribe(self, data: NDArray[np.int16]): - # self.samples = ( - # np.concatenate((self.samples, data)) if self.samples is not None else data - # ) + def transcribe(self, data: NDArray[np.int16]) -> str: normalized_data = data.astype(np.float32) / 32768.0 result = whisper.transcribe( self.whisper, normalized_data ) # TODO: handling of additional transcribe arguments (perhaps in model init) transcription = result["text"] + self.logger.info("transcription: %s", transcription) transcription = cast(str, transcription) - self.latest_transcription += transcription + self.latest_transcription = transcription + return transcription diff --git a/src/rai_asr/rai_asr/models/open_ai_whisper.py b/src/rai_asr/rai_asr/models/open_ai_whisper.py index a5734f20..a5d30ba7 100644 --- a/src/rai_asr/rai_asr/models/open_ai_whisper.py +++ b/src/rai_asr/rai_asr/models/open_ai_whisper.py @@ -13,6 +13,7 @@ # limitations under the License. 
import io +import logging import os from functools import partial @@ -36,21 +37,16 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): self.openai_client.audio.transcriptions.create, model=self.model_name, ) + self.logger = logging.getLogger(__name__) self.samples = [] - def add_samples(self, data: NDArray[np.int16]): + def transcribe(self, data: NDArray[np.int16]) -> str: normalized_data = data.astype(np.float32) / 32768.0 - self.samples = ( - np.concatenate([self.samples, normalized_data]) - if self.samples is not None - else data - ) - - def transcribe(self) -> str: with io.BytesIO() as temp_wav_buffer: - wavfile.write(temp_wav_buffer, self.sample_rate, self.samples) + wavfile.write(temp_wav_buffer, self.sample_rate, normalized_data) temp_wav_buffer.seek(0) temp_wav_buffer.name = "temp.wav" response = self.model(file=temp_wav_buffer, language=self.language) transcription = response.text + self.logger.info("transcription: %s", transcription) return transcription From 05219f9b428f58ef78d83cbb1b7ce78ac2ffa355 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Thu, 23 Jan 2025 16:03:14 +0100 Subject: [PATCH 10/16] chore: cleanup --- src/rai/rai/agents/voice_agent.py | 32 --------------------- src/rai_asr/rai_asr/models/local_whisper.py | 19 ------------ 2 files changed, 51 deletions(-) diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py index 0a8ae457..3b7ad90a 100644 --- a/src/rai/rai/agents/voice_agent.py +++ b/src/rai/rai/agents/voice_agent.py @@ -197,35 +197,3 @@ def transcription_thread(self, identifier: str): ) self.transcription_threads[identifier]["transcription"] = transcription self.transcription_threads[identifier]["event"].set() - - # with self.transcription_lock: - # while self.active_thread == identifier: - # with self.sample_buffer_lock: - # if len(self.sample_buffer) == 0: - # continue - # audio_data = self.sample_buffer.copy() - # self.sample_buffer = [] - # audio_data = np.concatenate(audio_data) - # with self.transcription_lock: - # self.transcription_model.transcribe(audio_data) - - # # transciption of the reminder of the buffer - # with self.sample_buffer_lock: - # if identifier in self.transcription_buffers: - # audio_data = self.transcription_buffers[identifier] - # audio_data = np.concatenate(audio_data) - # with self.transcription_lock: - # self.transcription_model.transcribe(audio_data) - # del self.transcription_buffers[identifier] - # # self.transcription_model.save_wav(f"{identifier}.wav") - # with self.transcription_lock: - # transcription = self.transcription_model.consume_transcription() - # self.logger.info(f"Transcription: {transcription}") - # self.connectors["ros2"].send_message( - # ROS2ARIMessage( - # {"data": transcription}, {"msg_type": "std_msgs/msg/String"} - # ), - # "/from_human", - # ) - # self.transcription_threads[identifier]["transcription"] = transcription - # self.transcription_threads[identifier]["event"].set() diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index 571f0233..377b3316 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -32,25 +32,6 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): self.whisper = whisper.load_model(self.model_name) self.logger = logging.getLogger(__name__) - # TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging - # self.samples = None - - # 
def save_wav(self, output_filename: str):
-    #     assert self.samples is not None, "No samples to save"
-    #     combined_samples = self.samples
-    #     if combined_samples.dtype.kind == "f":
-    #         combined_samples = np.clip(combined_samples, -1.0, 1.0)
-    #         combined_samples = (combined_samples * 32767).astype(np.int16)
-    #     elif combined_samples.dtype != np.int16:
-    #         combined_samples = combined_samples.astype(np.int16)
-
-    #     with wave.open(output_filename, "wb") as wav_file:
-    #         n_channels = 1
-    #         sampwidth = 2
-    #         wav_file.setnchannels(n_channels)
-    #         wav_file.setsampwidth(sampwidth)
-    #         wav_file.setframerate(self.sample_rate)
-    #         wav_file.writeframes(combined_samples.tobytes())

From 05cae9f13c7e022286fde0354a9e0e3cbf7a3b4a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Thu, 23 Jan 2025 17:18:39 +0100
Subject: [PATCH 11/16] fix: tests

---
 tests/communication/test_sound_device_connector.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/communication/test_sound_device_connector.py b/tests/communication/test_sound_device_connector.py
index b98ca255..a98e251a 100644
--- a/tests/communication/test_sound_device_connector.py
+++ b/tests/communication/test_sound_device_connector.py
@@ -56,7 +56,6 @@ def test_configure(
         audio_input_device.configred_devices[device_id].consumer_sampling_rate
         == 44100
     )
     assert audio_input_device.configred_devices[device_id].window_size_samples == 1024
-    assert audio_input_device.configred_devices[device_id].target_sampling_rate == 16000
     assert audio_input_device.configred_devices[device_id].dtype == "float32"

From dbe1ce4b71d78dff9bc93fd08c4b560fcda39dbf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 24 Jan 2025 16:16:05 +0100
Subject: [PATCH 12/16] fix: change typing to more generic

---
 src/rai/rai/communication/ros2/connectors.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/rai/rai/communication/ros2/connectors.py b/src/rai/rai/communication/ros2/connectors.py
index e51e4b47..9ee24c11 100644
--- a/src/rai/rai/communication/ros2/connectors.py
+++ b/src/rai/rai/communication/ros2/connectors.py
@@ -14,7 +14,7 @@

 import threading
 import uuid
-from typing import Any, Callable, Dict, Optional, TypedDict
+from typing import Any, Callable, Dict, Optional

 from rclpy.executors import MultiThreadedExecutor
 from rclpy.node import Node
@@ -23,14 +23,8 @@
 from rai.communication.ros2.api import ROS2ActionAPI, ROS2ServiceAPI, ROS2TopicAPI


-class ROS2ARIPayload(TypedDict):
-    data: Any
-
-
 class ROS2ARIMessage(ARIMessage):
-    def __init__(
-        self, payload: ROS2ARIPayload, metadata: Optional[Dict[str, Any]] = None
-    ):
+    def __init__(self, payload: Any, metadata: Optional[Dict[str, Any]] = None):
         super().__init__(payload, metadata)

From 7a2601f41913014f6710f70edeed8f84355188a2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 24 Jan 2025 16:21:15 +0100
Subject: [PATCH 13/16] fix: re-enable model resetting

---
 src/rai_asr/rai_asr/models/open_wake_word.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py
index d68dbf8d..5c0d9e74 100644
--- a/src/rai_asr/rai_asr/models/open_wake_word.py
+++ b/src/rai_asr/rai_asr/models/open_wake_word.py
@@ -42,6 +42,6 @@ def detected(
         ret.update({self.model_name: {"predictions": predictions}})
         for key, value in predictions.items():
            if 
value > self.threshold: - # self.model.reset() + self.model.reset() return True, ret return False, ret From 674dcd9f47a2f2fd386d498447cdd6d6a63f0b0b Mon Sep 17 00:00:00 2001 From: Maciej Majek Date: Fri, 24 Jan 2025 12:23:18 +0100 Subject: [PATCH 14/16] feat: add FasterWhisper --- poetry.lock | 75 ++++++++++++++++++--- pyproject.toml | 1 + src/rai_asr/rai_asr/models/local_whisper.py | 17 +++++ 3 files changed, 84 insertions(+), 9 deletions(-) diff --git a/poetry.lock b/poetry.lock index 31034f81..91c2abbc 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. [[package]] name = "absl-py" @@ -893,6 +893,40 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.11.1)", "types-Pil test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] test-no-images = ["pytest", "pytest-cov", "pytest-rerunfailures", "pytest-xdist", "wurlitzer"] +[[package]] +name = "ctranslate2" +version = "4.5.0" +description = "Fast inference engine for Transformer models" +optional = false +python-versions = ">=3.8" +files = [ + {file = "ctranslate2-4.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:241da685f8f7cb10b7afceeb3d879f778b56e6a1d55fc2964ddc949c80c9c7bb"}, + {file = "ctranslate2-4.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5328ec73b430ba1a99a85bc3b038291e7bbedc0c9987b354b3c8ca395a3b7e06"}, + {file = "ctranslate2-4.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b97ee9b15f75f84c35827df97ebe9c676f96c2e5118a2ed4d3efcf3c3e04a599"}, + {file = "ctranslate2-4.5.0-cp310-cp310-win_amd64.whl", hash = "sha256:5d9ec0a201d3c33ada1bb00929b3ff3d80642b34ca0d94465556dfa197d127c4"}, + {file = "ctranslate2-4.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1bc072da977abdd4b09f0d50a45de745818a247608aa3f2865ef9a579ff11851"}, + {file = "ctranslate2-4.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4c56ccf1aa723ba85f4ea56b4d945dc7d2ea7f074b5eb716c85be0c8e0311c24"}, + {file = "ctranslate2-4.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89db5b18dfc7f7bf84cafaf7cc36e885aafcaeac936977eefd3e4768fd7b2879"}, + {file = "ctranslate2-4.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:253993fbbe20cd7e2602de81e6159b259dadb47b9b59486d928396bd4a4ecdaa"}, + {file = "ctranslate2-4.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a0509f172edc994aec6870fe0a90c799d85fd7ddf564059d25b60932ab2e2c4"}, + {file = "ctranslate2-4.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c158f2ada6e3347388ad13c69e4a6a729ba40c035a400dd447995950ecf5e62f"}, + {file = "ctranslate2-4.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:de3c5877fce31a0fcf3b5edbc8d4e6e22fd94a86c6b49680740ef41130efffc1"}, + {file = "ctranslate2-4.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:a16a784ec7924166bdf3e86754feda0441f04d9851fc3412f34f1e2de7cbd51b"}, + {file = "ctranslate2-4.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:7c221153ecdda81e24679a07f0b577926879325a0347a89f8afaf2593641cb9b"}, + {file = "ctranslate2-4.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:919a5feab74f33694b66c0a5637f07ba7cf4995af87d960aca50e4cbe53b4054"}, + {file = "ctranslate2-4.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:45a45dabca3f9d8eb718685a792f9a7fc10af7362d318271181f16ebf54669b8"}, + {file = "ctranslate2-4.5.0-cp38-cp38-win_amd64.whl", hash = "sha256:5924e9adeff8b30ca0851e0f5ff13639d08e47d1219d27f615c0936a3cdedb57"}, + {file = "ctranslate2-4.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d9e8120817c51515175ab163655dc14b4e21eb381d7196fd43b843b0d50efaf1"}, + {file = "ctranslate2-4.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f790e77458b83e109a743d0f07e9e5c023208314f5c824c26d1e3ebc62a12f71"}, + {file = "ctranslate2-4.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af82185aa961869362c06ce33443b5207237790233b1614ccf92307a671aa72"}, + {file = "ctranslate2-4.5.0-cp39-cp39-win_amd64.whl", hash = "sha256:ccbccbdddb02e7c3b24666f2bc52cd475ca666fda8a317d23a97645eafd66dbe"}, +] + +[package.dependencies] +numpy = "*" +pyyaml = ">=5.3,<7" +setuptools = "*" + [[package]] name = "cycler" version = "0.12.1" @@ -1386,6 +1420,29 @@ files = [ {file = "fasteners-0.19.tar.gz", hash = "sha256:b4f37c3ac52d8a445af3a66bce57b33b5e90b97c696b7b984f530cf8f0ded09c"}, ] +[[package]] +name = "faster-whisper" +version = "1.1.1" +description = "Faster Whisper transcription with CTranslate2" +optional = false +python-versions = ">=3.9" +files = [ + {file = "faster-whisper-1.1.1.tar.gz", hash = "sha256:50d27571970c1be0c2b2680a2593d5d12f9f5d2f10484f242a1afbe7cb946604"}, + {file = "faster_whisper-1.1.1-py3-none-any.whl", hash = "sha256:5808dc334fb64fb4336921450abccfe5e313a859b31ba61def0ac7f639383d90"}, +] + +[package.dependencies] +av = ">=11" +ctranslate2 = ">=4.0,<5" +huggingface-hub = ">=0.13" +onnxruntime = ">=1.14,<2" +tokenizers = ">=0.13,<1" +tqdm = "*" + +[package.extras] +conversion = ["transformers[torch] (>=4.23)"] +dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"] + [[package]] name = "filelock" version = "3.16.1" @@ -2763,8 +2820,8 @@ langchain-core = ">=0.3.29,<0.4.0" langchain-text-splitters = ">=0.3.3,<0.4.0" langsmith = ">=0.1.17,<0.3" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2.7.4,<3.0.0" PyYAML = ">=5.3" @@ -2787,8 +2844,8 @@ files = [ boto3 = ">=1.35.74" langchain-core = ">=0.3.27,<0.4.0" numpy = [ - {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1,<2", markers = "python_version < \"3.12\""}, ] pydantic = ">=2,<3" @@ -2811,8 +2868,8 @@ langchain = ">=0.3.14,<0.4.0" langchain-core = ">=0.3.29,<0.4.0" langsmith = ">=0.1.125,<0.3" numpy = [ - {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.2,<3", markers = "python_version >= \"3.12\""}, + {version = ">=1.22.4,<2", markers = "python_version < \"3.12\""}, ] pydantic-settings = ">=2.4.0,<3.0.0" PyYAML = ">=5.3" @@ -4165,10 +4222,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] 
[[package]] @@ -4189,10 +4246,10 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.5", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\" and python_version < \"3.11\""}, {version = ">=1.21.2", markers = "platform_system != \"Darwin\" and python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] [[package]] @@ -4396,9 +4453,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -6841,8 +6898,8 @@ files = [ contourpy = {version = ">=1.0.7", markers = "python_version >= \"3.8\" and python_version < \"3.13\""} defusedxml = ">=0.7.1,<0.8.0" matplotlib = [ - {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, {version = ">=3.7.3", markers = "python_version >= \"3.12\""}, + {version = ">=3.6.0", markers = "python_version >= \"3.9\" and python_version < \"3.12\""}, ] numpy = {version = ">=1.21.2", markers = "python_version < \"3.13\""} opencv-python = ">=4.5.5.64" @@ -8175,4 +8232,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.10, <3.13" -content-hash = "ee424289e94a1e02622089d2226e5b97a4ca2d54e9de0787487e81353d11814e" +content-hash = "da1a7720082bf43b4efc7cd972b63f39882fc7e0d69340bbc436f18e889e55b2" diff --git a/pyproject.toml b/pyproject.toml index 4457ea4c..3534c688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ tomli = "^2.0.1" openwakeword = { git = "https://github.com/maciejmajek/openWakeWord.git", branch = "chore/remove-tflite-backend" } pytest-timeout = "^2.3.1" tomli-w = "^1.1.0" +faster-whisper = "^1.1.1" [tool.poetry.group.dev.dependencies] ipykernel = "^6.29.4" diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py index 377b3316..802a422d 100644 --- a/src/rai_asr/rai_asr/models/local_whisper.py +++ b/src/rai_asr/rai_asr/models/local_whisper.py @@ -18,6 +18,7 @@ import numpy as np import torch import whisper +from faster_whisper import WhisperModel from numpy._typing import NDArray from rai_asr.models.base import BaseTranscriptionModel @@ -43,3 +44,19 @@ def transcribe(self, data: NDArray[np.int16]) -> str: transcription = cast(str, transcription) self.latest_transcription = transcription return transcription + + +class FasterWhisper(BaseTranscriptionModel): + def __init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): + super().__init__(model_name, sample_rate, language) + self.model = WhisperModel(model_name, **kwargs) + self.logger = logging.getLogger(__name__) + + def transcribe(self, data: NDArray[np.int16]) -> str: + normalized_data = data.astype(np.float32) / 32768.0 + segments, _ = self.model.transcribe(normalized_data) + transcription = " ".join(segment.text for segment in segments) + self.logger.info("transcription: %s", transcription) + return transcription From d40653a2716e5cd3867c2c5fd55fcd6071e38241 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?= Date: Fri, 24 Jan 2025 16:54:21 +0100 Subject: [PATCH 15/16] feat: add call 
method
---
 src/rai/rai/agents/voice_agent.py            | 4 ++--
 src/rai_asr/rai_asr/models/base.py           | 7 ++++++-
 src/rai_asr/rai_asr/models/open_wake_word.py | 2 +-
 src/rai_asr/rai_asr/models/silero_vad.py     | 2 +-
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/rai/rai/agents/voice_agent.py b/src/rai/rai/agents/voice_agent.py
index 3b7ad90a..3fc77025 100644
--- a/src/rai/rai/agents/voice_agent.py
+++ b/src/rai/rai/agents/voice_agent.py
@@ -133,7 +133,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
                 self.transcription_threads[thread_id]["thread"].join()
                 self.transcription_threads[thread_id]["joined"] = True

-        voice_detected, output_parameters = self.vad.detected(indata, {})
+        voice_detected, output_parameters = self.vad(indata, {})
         should_record = False
         # TODO: second condition is temporary
         if voice_detected and not self.recording_started:
@@ -179,7 +179,7 @@ def should_record(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> bool:
         for model in self.should_record_pipeline:
-            detected, output = model.detected(audio_data, input_parameters)
+            detected, output = model(audio_data, input_parameters)
             if detected:
                 return True
         return False
diff --git a/src/rai_asr/rai_asr/models/base.py b/src/rai_asr/rai_asr/models/base.py
index d45a6176..13142df8 100644
--- a/src/rai_asr/rai_asr/models/base.py
+++ b/src/rai_asr/rai_asr/models/base.py
@@ -22,8 +22,13 @@


 class BaseVoiceDetectionModel(ABC):
+    def __call__(
+        self, audio_data: NDArray, input_parameters: dict[str, Any]
+    ) -> Tuple[bool, dict[str, Any]]:
+        return self.detect(audio_data, input_parameters)
+
     @abstractmethod
-    def detected(
+    def detect(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
         pass
diff --git a/src/rai_asr/rai_asr/models/open_wake_word.py b/src/rai_asr/rai_asr/models/open_wake_word.py
index 5c0d9e74..1fb4211e 100644
--- a/src/rai_asr/rai_asr/models/open_wake_word.py
+++ b/src/rai_asr/rai_asr/models/open_wake_word.py
@@ -34,7 +34,7 @@ def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
         )
         self.threshold = threshold

-    def detected(
+    def detect(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
         predictions = self.model.predict(audio_data)
diff --git a/src/rai_asr/rai_asr/models/silero_vad.py b/src/rai_asr/rai_asr/models/silero_vad.py
index 34ad0135..fdecb8b5 100644
--- a/src/rai_asr/rai_asr/models/silero_vad.py
+++ b/src/rai_asr/rai_asr/models/silero_vad.py
@@ -48,7 +48,7 @@ def int2float(self, sound: NDArray[np.int16]):
         converted_sound = converted_sound.squeeze()
         return converted_sound

-    def detected(
+    def detect(
         self, audio_data: NDArray, input_parameters: dict[str, Any]
     ) -> Tuple[bool, dict[str, Any]]:
         vad_confidence = self.model(

From 8eaf28ed0376fb97a65c0b5afa6cb5061cec879e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kajetan=20Rachwa=C5=82?=
Date: Fri, 24 Jan 2025 16:56:28 +0100
Subject: [PATCH 16/16] feat: add kwargs to ASR models

---
 src/rai_asr/rai_asr/models/local_whisper.py   | 8 +++++---
 src/rai_asr/rai_asr/models/open_ai_whisper.py | 5 ++++-
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/src/rai_asr/rai_asr/models/local_whisper.py b/src/rai_asr/rai_asr/models/local_whisper.py
index 802a422d..f8292e33 100644
--- a/src/rai_asr/rai_asr/models/local_whisper.py
+++ b/src/rai_asr/rai_asr/models/local_whisper.py
@@ -25,12 +25,14 @@ class LocalWhisper(BaseTranscriptionModel):
-    def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
+    def 
__init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): super().__init__(model_name, sample_rate, language) if torch.cuda.is_available(): - self.whisper = whisper.load_model(self.model_name, device="cuda") + self.whisper = whisper.load_model(self.model_name, device="cuda", **kwargs) else: - self.whisper = whisper.load_model(self.model_name) + self.whisper = whisper.load_model(self.model_name, **kwargs) self.logger = logging.getLogger(__name__) diff --git a/src/rai_asr/rai_asr/models/open_ai_whisper.py b/src/rai_asr/rai_asr/models/open_ai_whisper.py index a5d30ba7..0f74dd09 100644 --- a/src/rai_asr/rai_asr/models/open_ai_whisper.py +++ b/src/rai_asr/rai_asr/models/open_ai_whisper.py @@ -26,7 +26,9 @@ class OpenAIWhisper(BaseTranscriptionModel): - def __init__(self, model_name: str, sample_rate: int, language: str = "en"): + def __init__( + self, model_name: str, sample_rate: int, language: str = "en", **kwargs + ): super().__init__(model_name, sample_rate, language) api_key = os.getenv("OPENAI_API_KEY") if api_key is None: @@ -36,6 +38,7 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"): self.model = partial( self.openai_client.audio.transcriptions.create, model=self.model_name, + **kwargs, ) self.logger = logging.getLogger(__name__) self.samples = []
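
Usage sketch for the series as a whole. This snippet is illustrative, not part
of the patches: the AudioInputDeviceConfig field names are inferred from the
test assertions in PATCH 11, the SileroVAD and OpenWakeWord class names and
the add_detection_model() argument list are assumptions (they are truncated in
the diffs above), and the FasterWhisper keyword arguments are forwarded
verbatim to faster_whisper.WhisperModel (PATCH 14 / PATCH 16).

    from rai.agents import VoiceRecognitionAgent
    from rai.communication import AudioInputDeviceConfig
    from rai_asr.models.local_whisper import FasterWhisper
    from rai_asr.models.open_wake_word import OpenWakeWord
    from rai_asr.models.silero_vad import SileroVAD

    # Field names mirror the attributes asserted on configred_devices in the
    # tests; the exact constructor shape is an assumption.
    microphone_config = AudioInputDeviceConfig(
        consumer_sampling_rate=44100,
        window_size_samples=1024,
        dtype="float32",
    )

    agent = VoiceRecognitionAgent(
        microphone_device_id=0,  # id-based for now, see the TODO in the constructor diff
        microphone_config=microphone_config,
        ros2_name="voice_agent",
        # Extra kwargs go straight to faster_whisper.WhisperModel (PATCH 16).
        transcription_model=FasterWhisper(
            "tiny", 16000, device="cpu", compute_type="int8"
        ),
        vad=SileroVAD(),  # constructor arguments assumed
    )
    # Hypothetical call: the full add_detection_model() signature is not
    # visible in the diffs above, so this argument list is a guess.
    agent.add_detection_model(OpenWakeWord("path/to/wake_word.onnx", threshold=0.5))

    agent.run()   # non-blocking: starts the microphone stream and the VAD loop
    input("Recording... press Enter to stop\n")
    agent.stop()  # joins all pending transcription threads before returning

While the agent runs, each finished utterance should be published as a
std_msgs/msg/String on /from_human, observable with: ros2 topic echo /from_human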