Skip to content

Commit

Permalink
feat: basic multithreading implementation of transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk committed Jan 9, 2025
1 parent 5dfbcf9 commit 4826a5e
Show file tree
Hide file tree
Showing 6 changed files with 134 additions and 80 deletions.
4 changes: 0 additions & 4 deletions src/rai/rai/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ def __init__(
connectors = {}
self.connectors: dict[str, BaseConnector] = connectors

@abstractmethod
def setup(self, *args, **kwargs):
pass

@abstractmethod
def run(self, *args, **kwargs):
pass
171 changes: 115 additions & 56 deletions src/rai/rai/agents/voice_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,46 +14,61 @@


import time
from threading import Lock, Thread
from typing import Any, List, cast
from threading import Event, Lock, Thread
from typing import Any, List, TypedDict
from uuid import uuid4

import numpy as np
from numpy.typing import NDArray

from rai.agents.base import BaseAgent
from rai.communication import AudioInputDeviceConfig, StreamingAudioInputDevice
from rai_asr.models.base import BaseTranscriptionModel, BaseVoiceDetectionModel
from rai_asr.models import BaseTranscriptionModel, BaseVoiceDetectionModel


class VoiceRecognitionAgent(BaseAgent):
def __init__(self):
super().__init__(connectors={"microphone": StreamingAudioInputDevice()})
self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []
self.transcription_lock = Lock()
self.shared_samples = []
self.recording_started = False
self.ran_setup = False
class ThreadData(TypedDict):
    """Bookkeeping record kept for one transcription worker thread."""
    # Thread object running the transcription for this recording
    thread: Thread
    # Set by the worker after it has stored its transcription
    event: Event
    # Transcribed text; registered as an empty string and filled in by the worker
    transcription: str

def __call__(self):
self.run()

def setup(
class VoiceRecognitionAgent(BaseAgent):
def __init__(
self,
microphone_device_id: int, # TODO: Change to name based instead of id based identification
microphone_config: AudioInputDeviceConfig,
transcription_model: BaseTranscriptionModel,
vad: BaseVoiceDetectionModel,
grace_period: float = 1.0,
):
self.connectors["microphone"] = cast(
StreamingAudioInputDevice, self.connectors["microphone"]
microphone = StreamingAudioInputDevice()
microphone.configure_device(
target=str(microphone_device_id), config=microphone_config
)
super().__init__(connectors={"microphone": microphone})
self.microphone_device_id = str(microphone_device_id)
self.connectors["microphone"].configure_device(
target=self.microphone_device_id, config=microphone_config
)
self.should_record_pipeline: List[BaseVoiceDetectionModel] = []
self.should_stop_pipeline: List[BaseVoiceDetectionModel] = []

self.transcription_model = transcription_model
self.ran_setup = True
self.running = False
self.transcription_lock = Lock()

self.vad: BaseVoiceDetectionModel = vad

self.grace_period = grace_period
self.grace_period_start = 0

self.recording_started = False
self.ran_setup = False

self.sample_buffer = []
self.sample_buffer_lock = Lock()
self.active_thread = ""
self.transcription_threads: dict[str, ThreadData] = {}
self.buffer_reminders: dict[str, list[NDArray]] = {}

def __call__(self):
self.run()

def add_detection_model(
self, model: BaseVoiceDetectionModel, pipeline: str = "record"
Expand All @@ -70,49 +85,93 @@ def run(self):
self.listener_handle = self.connectors["microphone"].start_action(
self.microphone_device_id, self.on_new_sample
)
self.transcription_thread = Thread(target=self._transcription_function)
self.transcription_thread.start()

def stop(self):
self.running = False
self.connectors["microphone"].terminate_action(self.listener_handle)
self.transcription_thread.join()
to_finish = list(self.transcription_threads.keys())
while len(to_finish) > 0:
for thread_id in self.transcription_threads:
if self.transcription_threads[thread_id]["event"].is_set():
self.transcription_threads[thread_id]["thread"].join()
to_finish.remove(thread_id)
else:
print(f"Waiting for transcription of {thread_id} to finish")

def on_new_sample(self, indata: np.ndarray, status_flags: dict[str, Any]):
should_stop = self.should_stop_recording(indata)
if self.should_start_recording(indata):
sample_time = time.time()
with self.sample_buffer_lock:
self.sample_buffer.append(indata)
if not self.recording_started and len(self.sample_buffer) > 5:
self.sample_buffer = self.sample_buffer[-5:]

voice_detected, output_parameters = self.vad.detected(indata, {})
should_record = False
# TODO: second condition is temporary
if voice_detected and not self.recording_started:
should_record = self.should_record(indata, output_parameters)

if should_record:
print("Start recording")
self.recording_started = True
if self.recording_started and not should_stop:
with self.transcription_lock:
self.shared_samples.extend(indata)
thread_id = str(uuid4())[0:8]
transcription_thread = Thread(
target=self.transcription_thread,
args=[thread_id],
)
transcription_finished = Event()
self.active_thread = thread_id
transcription_thread.start()
self.transcription_threads[thread_id] = {
"thread": transcription_thread,
"event": transcription_finished,
"transcription": "",
}

if voice_detected:
self.grace_period_start = sample_time

def should_start_recording(self, audio_data: NDArray[np.int16]) -> bool:
output_parameters = {}
if (
self.recording_started
and sample_time - self.grace_period_start > self.grace_period
):
print("Stop recording")
self.recording_started = False
self.grace_period_start = 0
with self.sample_buffer_lock:
self.buffer_reminders[self.active_thread] = self.sample_buffer
self.sample_buffer = []
self.active_thread = ""

def should_record(
self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> bool:
for model in self.should_record_pipeline:
should_listen, output_parameters = model.detected(
audio_data, output_parameters
)
print(should_listen, output_parameters)
if not should_listen:
return False
return True

def should_stop_recording(self, audio_data: NDArray[np.int16]) -> bool:
output_parameters = {}
for model in self.should_stop_pipeline:
should_listen, output_parameters = model.detected(
audio_data, output_parameters
)
if should_listen:
detected, output = model.detected(audio_data, input_parameters)
print(f"Detected: {detected}: {output}")
if detected:
return True
return False

def _transcription_function(self):
while self.running:
time.sleep(0.1)
# critical section for samples
with self.transcription_lock:
samples = np.array(self.shared_samples)
self.shared_samples = []
# end critical section for samples
self.transcription_model.add_samples(samples)
def transcription_thread(self, identifier: str):
    """Transcribe the audio of one recording session on a worker thread.

    While ``identifier`` is the active recording (``self.active_thread``),
    buffered samples are drained from ``self.sample_buffer`` and passed to
    the transcription model. Once the recording is no longer active, the
    samples stored for it in ``self.buffer_reminders`` (if any) are
    transcribed, the text is consumed from the model, written to
    ``self.transcription_threads[identifier]["transcription"]`` and the
    entry's ``event`` is set.

    Args:
        identifier: Key of this session in ``self.transcription_threads``
            and ``self.buffer_reminders``.
    """
    # Serializes use of the shared transcription model between workers.
    # NOTE(review): lock scope assumed to extend through
    # consume_transcription() -- confirm against the intended design.
    with self.transcription_lock:
        while self.active_thread == identifier:
            # Hold the buffer lock only for the swap so the audio callback
            # is never blocked behind a (slow) model call.
            with self.sample_buffer_lock:
                chunks, self.sample_buffer = self.sample_buffer, []
            if not chunks:
                # Nothing buffered yet: back off instead of spinning on the lock.
                time.sleep(0.01)
                continue
            self.transcription_model.transcribe(np.concatenate(chunks))

        # transcription of the remainder of the buffer
        with self.sample_buffer_lock:
            remainder = self.buffer_reminders.pop(identifier, None)
        # An empty remainder is possible when the buffer was drained right
        # before recording stopped; np.concatenate rejects an empty list.
        if remainder:
            self.transcription_model.transcribe(np.concatenate(remainder))
        transcription = self.transcription_model.consume_transcription()
    self.transcription_threads[identifier]["transcription"] = transcription
    self.transcription_threads[identifier]["event"].set()
    # TODO: sending the transcription
11 changes: 7 additions & 4 deletions src/rai_asr/rai_asr/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@ def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
self.sample_rate = sample_rate
self.language = language

@abstractmethod
def add_samples(self, data: NDArray[np.int16]):
pass
self.latest_transcription = ""

def consume_transcription(self) -> str:
    """Return the stored latest transcription and reset it to an empty string."""
    pending, self.latest_transcription = self.latest_transcription, ""
    return pending

@abstractmethod
def transcribe(self) -> str:
def transcribe(self, data: NDArray[np.int16]):
pass
24 changes: 11 additions & 13 deletions src/rai_asr/rai_asr/models/local_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from typing import cast

import numpy as np
import torch
import whisper
from numpy._typing import NDArray

Expand All @@ -24,24 +25,21 @@
class LocalWhisper(BaseTranscriptionModel):
def __init__(self, model_name: str, sample_rate: int, language: str = "en"):
super().__init__(model_name, sample_rate, language)
self.whisper = whisper.load_model(self.model_name)
if torch.cuda.is_available():
print("Using CUDA")
self.whisper = whisper.load_model(self.model_name, device="cuda")
else:
self.whisper = whisper.load_model(self.model_name)

self.samples = None

def add_samples(self, data: NDArray[np.int16]):
def transcribe(self, data: NDArray[np.int16]):
normalized_data = data.astype(np.float32) / 32768.0
self.samples = (
np.concatenate([self.samples, normalized_data])
if self.samples is not None
else data
)

def transcribe(self) -> str:
if self.samples is None:
raise ValueError("No samples to transcribe")
print("Starting transcription")
result = whisper.transcribe(
self.whisper, self.samples
self.whisper, normalized_data
) # TODO: handling of additional transcribe arguments (perhaps in model init)
print("Finished transcription")
transcription = result["text"]
transcription = cast(str, transcription)
return transcription
self.latest_transcription += transcription
3 changes: 1 addition & 2 deletions src/rai_asr/rai_asr/models/open_wake_word.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,11 @@ def __init__(self, wake_word_model_path: str, threshold: float = 0.5):
def detected(
self, audio_data: NDArray, input_parameters: dict[str, Any]
) -> Tuple[bool, dict[str, Any]]:
print(len(audio_data))
predictions = self.model.predict(audio_data)
ret = input_parameters.copy()
ret.update({self.model_name: {"predictions": predictions}})
for key, value in predictions.items():
if value > self.threshold:
self.model.reset()
# self.model.reset()
return True, ret
return False, ret
1 change: 0 additions & 1 deletion src/rai_asr/rai_asr/models/silero_vad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,5 @@ def detected(
).item()
ret = input_parameters.copy()
ret.update({self.model_name: {"vad_confidence": vad_confidence}})
self.model.reset_states() # NOTE: see streaming example at the bottom https://github.com/snakers4/silero-vad/wiki/Examples-and-Dependencies#dependencies

return vad_confidence > self.threshold, ret

0 comments on commit 4826a5e

Please sign in to comment.