Skip to content

Commit

Permalink
VAD fallback (#97)
Browse files Browse the repository at this point in the history
* Silero VAD preferred with webrtc fallback

* webrtc VAD neds a different sample size

* fixup

* fixup
  • Loading branch information
chadbailey59 authored Apr 4, 2024
1 parent 385b51a commit 03ea208
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 20 deletions.
25 changes: 25 additions & 0 deletions src/dailyai/transports/daily_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import threading
import types

from enum import Enum
from functools import partial
from typing import Any

Expand Down Expand Up @@ -33,6 +34,11 @@

from dailyai.transports.threaded_transport import ThreadedTransport

NUM_CHANNELS = 1

SPEECH_THRESHOLD = 0.90
VAD_RESET_PERIOD_MS = 2000


class DailyTransport(ThreadedTransport, EventHandler):
_daily_initialized = False
Expand All @@ -55,6 +61,7 @@ def __init__(
start_transcription: bool = False,
**kwargs,
):
kwargs['has_webrtc_vad'] = True
# This will call ThreadedTransport.__init__ method, not EventHandler
super().__init__(**kwargs)

Expand Down Expand Up @@ -86,6 +93,12 @@ def __init__(

self._event_handlers = {}

self.webrtc_vad = Daily.create_native_vad(
reset_period_ms=VAD_RESET_PERIOD_MS,
sample_rate=self._speaker_sample_rate,
channels=NUM_CHANNELS
)

def _patch_method(self, event_name, *args, **kwargs):
try:
for handler in self._event_handlers[event_name]:
Expand All @@ -106,6 +119,18 @@ def _patch_method(self, event_name, *args, **kwargs):
self._logger.error(f"Exception in event handler {event_name}: {e}")
raise e

def _webrtc_vad_analyze(self):
buffer = self.read_audio_frames(
int(self._vad_samples))
if len(buffer) > 0:
confidence = self.webrtc_vad.analyze_frames(buffer)
# yeses = int(confidence * 20.0)
# nos = 20 - yeses
# out = "!" * yeses + "." * nos
# print(f"!!! confidence: {out} {confidence}")
talking = confidence > SPEECH_THRESHOLD
return talking

def add_event_handler(self, event_name: str, handler):
if not event_name.startswith("on_"):
raise Exception(
Expand Down
53 changes: 33 additions & 20 deletions src/dailyai/transports/threaded_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ def int2float(sound):
return sound


SAMPLE_RATE = 16000


class VADState(Enum):
QUIET = 1
STARTING = 2
Expand All @@ -61,11 +58,12 @@ def __init__(
self._vad_stop_s = kwargs.get("vad_stop_s") or 0.8
self._context = kwargs.get("context") or []
self._vad_enabled = kwargs.get("vad_enabled") or False

self._has_webrtc_vad = kwargs.get("has_webrtc_vad") or False
if self._vad_enabled and self._speaker_enabled:
raise Exception(
"Sorry, you can't use speaker_enabled and vad_enabled at the same time. Please set one to False."
)
self._vad_samples = 1536

if self._vad_enabled:
try:
Expand All @@ -79,14 +77,19 @@ def __init__(
(self.model, self.utils) = torch.hub.load(
repo_or_dir="snakers4/silero-vad", model="silero_vad", force_reload=False
)
self._logger.debug("Loaded Silero VAD")

except ModuleNotFoundError as e:
print(f"Exception: {e}")
print("In order to use VAD, you'll need to install the `torch` and `torchaudio` modules.")
raise Exception(f"Missing module(s): {e}")
if self._has_webrtc_vad:
self._logger.debug(f"Couldn't load torch; using webrtc VAD")
self._vad_samples = int(self._speaker_sample_rate / 100.0)
else:
self._logger.error(f"Exception: {e}")
self._logger.error(
"In order to use VAD, you'll need to install the `torch` and `torchaudio` modules.")
raise Exception(f"Missing module(s): {e}")

self._vad_samples = 1536
vad_frame_s = self._vad_samples / SAMPLE_RATE
vad_frame_s = self._vad_samples / self._speaker_sample_rate
self._vad_start_frames = round(self._vad_start_s / vad_frame_s)
self._vad_stop_frames = round(self._vad_stop_s / vad_frame_s)
self._vad_starting_count = 0
Expand Down Expand Up @@ -262,19 +265,28 @@ def read_audio_frames(self, desired_frame_count):
def _prerun(self):
pass

def _silero_vad_analyze(self):
audio_chunk = self.read_audio_frames(self._vad_samples)
audio_int16 = np.frombuffer(audio_chunk, np.int16)
audio_float32 = int2float(audio_int16)
new_confidence = self.model(
torch.from_numpy(audio_float32), 16000).item()
# yeses = int(new_confidence * 20.0)
# nos = 20 - yeses
# out = "!" * yeses + "." * nos
# print(f"!!! confidence: {out}")
speaking = new_confidence > 0.5
return speaking

def _vad(self):
# CB: Starting silero VAD stuff
# TODO-CB: Probably need to force virtual speaker creation if we're
# going to build this in?
# TODO-CB: pyaudio installation
while not self._stop_threads.is_set():
audio_chunk = self.read_audio_frames(self._vad_samples)
audio_int16 = np.frombuffer(audio_chunk, np.int16)
audio_float32 = int2float(audio_int16)
new_confidence = self.model(
torch.from_numpy(audio_float32), 16000).item()
speaking = new_confidence > 0.5

while not self._stop_threads.is_set():
if hasattr(self, 'model'): # we can use Silero
speaking = self._silero_vad_analyze()
elif self._has_webrtc_vad:
speaking = self._webrtc_vad_analyze()
else:
raise Exception("VAD is running with no VAD service available")
if speaking:
match self._vad_state:
case VADState.QUIET:
Expand Down Expand Up @@ -311,6 +323,7 @@ def _vad(self):
self._vad_state == VADState.STOPPING
and self._vad_stopping_count >= self._vad_stop_frames
):

if self._loop:
asyncio.run_coroutine_threadsafe(
self.receive_queue.put(
Expand Down

0 comments on commit 03ea208

Please sign in to comment.