Skip to content

Commit

Permalink
feat: functional multithreading transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Jan 15, 2025
1 parent e9de014 commit 225178f
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 20 deletions.
32 changes: 23 additions & 9 deletions src/rai/rai/agents/voice_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.


import logging
import time
from threading import Event, Lock, Thread
from typing import Any, List, TypedDict
from typing import Any, List, Optional, TypedDict
from uuid import uuid4

import numpy as np
Expand All @@ -40,7 +41,12 @@ def __init__(
transcription_model: BaseTranscriptionModel,
vad: BaseVoiceDetectionModel,
grace_period: float = 1.0,
logger: Optional[logging.Logger] = None,
):
if logger is None:
self.logger = logging.getLogger(__name__)
else:
self.logger = logger
microphone = StreamingAudioInputDevice()
microphone.configure_device(
target=str(microphone_device_id), config=microphone_config
Expand Down Expand Up @@ -87,16 +93,20 @@ def run(self):
)

def stop(self):
self.logger.info("Stopping voice agent")
self.running = False
self.connectors["microphone"].terminate_action(self.listener_handle)
to_finish = list(self.transcription_threads.keys())
while len(to_finish) > 0:
to_finish = len(list(self.transcription_threads.keys()))
while to_finish > 0:
for thread_id in self.transcription_threads:
if self.transcription_threads[thread_id]["event"].is_set():
self.transcription_threads[thread_id]["thread"].join()
to_finish.remove(thread_id)
to_finish -= 1
else:
print(f"Waiting for transcription of {thread_id} to finish")
self.logger.info(
f"Waiting for transcription of {thread_id} to finish..."
)
self.logger.info("Voice agent stopped")

def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
sample_time = time.time()
Expand All @@ -112,7 +122,7 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
should_record = self.should_record(indata, output_parameters)

if should_record:
print("Start recording")
self.logger.info("starting recording...")
self.recording_started = True
thread_id = str(uuid4())[0:8]
transcription_thread = Thread(
Expand All @@ -129,13 +139,14 @@ def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
}

if voice_detected:
self.logger.debug("Voice detected... resetting grace period")
self.grace_period_start = sample_time

if (
self.recording_started
and sample_time - self.grace_period_start > self.grace_period
):
print("Stop recording")
self.logger.info("Grace period ended... stopping recording")
self.recording_started = False
self.grace_period_start = 0
with self.sample_buffer_lock:
Expand All @@ -148,12 +159,12 @@ def should_record(
) -> bool:
for model in self.should_record_pipeline:
detected, output = model.detected(audio_data, input_parameters)
print(f"Detected: {detected}: {output}")
if detected:
return True
return False

def transcription_thread(self, identifier: str):
self.logger.info(f"transcription thread {identifier} started")
with self.transcription_lock:
while self.active_thread == identifier:
with self.sample_buffer_lock:
Expand All @@ -171,7 +182,10 @@ def transcription_thread(self, identifier: str):
audio_data = np.concatenate(audio_data)
self.transcription_model.transcribe(audio_data)
del self.buffer_reminders[identifier]
# self.transcription_model.save_wav(f"{identifier}.wav")
transcription = self.transcription_model.consume_transcription()
self.transcription_threads[identifier]["transcription"] = transcription
self.transcription_threads[identifier]["event"].set()
# TODO: sending the transcription
# TODO: sending the transcription once https://github.com/RobotecAI/rai/pull/360 is merged
self.logger.info(f"transcription thread {identifier} finished")
print(transcription)
9 changes: 3 additions & 6 deletions src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def __init__(self, msg: str):
class AudioInputDeviceConfig(TypedDict):
block_size: int
consumer_sampling_rate: int
target_sampling_rate: int
dtype: str
device_number: Optional[int]

Expand All @@ -44,7 +43,6 @@ class ConfiguredAudioInputDevice:
sample_rate (int): Device sample rate
consumer_sampling_rate (int): The sampling rate of the consumer
window_size_samples (int): The size of the window in samples
target_sampling_rate (int): The target sampling rate
dtype (str): The data type of the audio samples
"""

Expand All @@ -58,7 +56,6 @@ def __init__(self, config: AudioInputDeviceConfig):
self.window_size_samples = int(
config["block_size"] * self.sample_rate / config["consumer_sampling_rate"]
)
self.target_sampling_rate = int(config["target_sampling_rate"])
self.dtype = config["dtype"]


Expand Down Expand Up @@ -108,9 +105,9 @@ def start_action(

def callback(indata: np.ndarray, frames: int, _, status: CallbackFlags):
indata = indata.flatten()
sample_time_length = len(indata) / target_device.target_sampling_rate
if target_device.sample_rate != target_device.target_sampling_rate:
indata = resample(indata, int(sample_time_length * target_device.target_sampling_rate)) # type: ignore
sample_time_length = len(indata) / target_device.sample_rate
if target_device.sample_rate != target_device.consumer_sampling_rate:
indata = resample(indata, int(sample_time_length * target_device.consumer_sampling_rate)) # type: ignore
flag_dict = {
"input_overflow": status.input_overflow,
"input_underflow": status.input_underflow,
Expand Down
31 changes: 27 additions & 4 deletions src/rai_asr/rai_asr/models/local_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,43 @@ class LocalWhisper(BaseTranscriptionModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
if torch.cuda.is_available():
print("Using CUDA")
self.whisper = whisper.load_model(self.model_name, device="cuda")
else:
self.whisper = whisper.load_model(self.model_name)

self.samples = None
# TODO: remove sample storage before PR is merged, this is just to enable saving wav files for debugging
# self.samples = None

def consume_transcription(self) -> str:
ret = super().consume_transcription()
# self.samples = None
return ret

# def save_wav(self, output_filename: str):
# assert self.samples is not None, "No samples to save"
# combined_samples = self.samples
# if combined_samples.dtype.kind == "f":
# combined_samples = np.clip(combined_samples, -1.0, 1.0)
# combined_samples = (combined_samples * 32767).astype(np.int16)
# elif combined_samples.dtype != np.int16:
# combined_samples = combined_samples.astype(np.int16)

# with wave.open(output_filename, "wb") as wav_file:
# n_channels = 1
# sampwidth = 2
# wav_file.setnchannels(n_channels)
# wav_file.setsampwidth(sampwidth)
# wav_file.setframerate(self.sample_rate)
# wav_file.writeframes(combined_samples.tobytes())

def transcribe(self, data: NDArray[np.int16]):
# self.samples = (
# np.concatenate((self.samples, data)) if self.samples is not None else data
# )
normalized_data = data.astype(np.float32) / 32768.0
print("Starting transcription")
result = whisper.transcribe(
self.whisper, normalized_data
) # TODO: handling of additional transcribe arguments (perhaps in model init)
print("Finished transcription")
transcription = result["text"]
transcription = cast(str, transcription)
self.latest_transcription += transcription
1 change: 0 additions & 1 deletion tests/communication/test_sound_device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def device_config():
return {
"block_size": 1024,
"consumer_sampling_rate": 44100,
"target_sampling_rate": 16000,
"dtype": "float32",
}

Expand Down

0 comments on commit 225178f

Please sign in to comment.