diff --git a/Audio-Transcription-Chrome/background.js b/Audio-Transcription-Chrome/background.js
index a5028a78..52ce4a7e 100644
--- a/Audio-Transcription-Chrome/background.js
+++ b/Audio-Transcription-Chrome/background.js
@@ -157,7 +157,8 @@ async function startCapture(options) {
multilingual: options.useMultilingual,
language: options.language,
task: options.task,
- modelSize: options.modelSize
+ modelSize: options.modelSize,
+ useVad: options.useVad,
},
});
} else {
diff --git a/Audio-Transcription-Chrome/options.js b/Audio-Transcription-Chrome/options.js
index 435d75b2..6c3ef628 100644
--- a/Audio-Transcription-Chrome/options.js
+++ b/Audio-Transcription-Chrome/options.js
@@ -99,7 +99,8 @@ async function startRecord(option) {
uid: uuid,
language: option.language,
task: option.task,
- model: option.modelSize
+ model: option.modelSize,
+ use_vad: option.useVad
})
);
};
diff --git a/Audio-Transcription-Chrome/popup.html b/Audio-Transcription-Chrome/popup.html
index 8d45bcb2..13203368 100644
--- a/Audio-Transcription-Chrome/popup.html
+++ b/Audio-Transcription-Chrome/popup.html
@@ -15,6 +15,10 @@
+    <!-- Assumed markup: a VAD toggle mirroring the Firefox popup's useVadCheckbox -->
+    <input type="checkbox" id="useVadCheckbox" checked>
+    <label for="useVadCheckbox">Use Voice Activity Detection</label>
+
diff --git a/Audio-Transcription-Firefox/popup.js b/Audio-Transcription-Firefox/popup.js
index d71dc0a1..61c37a5b 100644
--- a/Audio-Transcription-Firefox/popup.js
+++ b/Audio-Transcription-Firefox/popup.js
@@ -3,6 +3,7 @@ document.addEventListener("DOMContentLoaded", function() {
const stopButton = document.getElementById("stopCapture");
const useServerCheckbox = document.getElementById("useServerCheckbox");
+ const useVadCheckbox = document.getElementById("useVadCheckbox");
const languageDropdown = document.getElementById('languageDropdown');
const taskDropdown = document.getElementById('taskDropdown');
const modelSizeDropdown = document.getElementById('modelSizeDropdown');
@@ -34,6 +35,12 @@ document.addEventListener("DOMContentLoaded", function() {
}
});
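+  // Restore the persisted VAD preference, if one was saved.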
+ browser.storage.local.get("useVadState", ({ useVadState }) => {
+ if (useVadState !== undefined) {
+ useVadCheckbox.checked = useVadState;
+ }
+ });
+
browser.storage.local.get("selectedLanguage", ({ selectedLanguage: storedLanguage }) => {
if (storedLanguage !== undefined) {
languageDropdown.value = storedLanguage;
@@ -76,7 +83,8 @@ document.addEventListener("DOMContentLoaded", function() {
port: port,
language: selectedLanguage,
task: selectedTask,
- modelSize: selectedModelSize
+ modelSize: selectedModelSize,
+ useVad: useVadCheckbox.checked,
}
});
toggleCaptureButtons(true);
@@ -115,6 +123,7 @@ document.addEventListener("DOMContentLoaded", function() {
startButton.disabled = isCapturing;
stopButton.disabled = !isCapturing;
useServerCheckbox.disabled = isCapturing;
+ useVadCheckbox.disabled = isCapturing;
modelSizeDropdown.disabled = isCapturing;
languageDropdown.disabled = isCapturing;
taskDropdown.disabled = isCapturing;
@@ -128,6 +137,11 @@ document.addEventListener("DOMContentLoaded", function() {
browser.storage.local.set({ useServerState });
});
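+  // Persist the VAD preference whenever the checkbox changes.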
+ useVadCheckbox.addEventListener("change", () => {
+ const useVadState = useVadCheckbox.checked;
+ browser.storage.local.set({ useVadState });
+ });
+
languageDropdown.addEventListener('change', function() {
if (languageDropdown.value === "") {
selectedLanguage = null;
diff --git a/README.md b/README.md
index f62292e4..addbfbf2 100644
--- a/README.md
+++ b/README.md
@@ -64,7 +64,8 @@ client = TranscriptionClient(
9090,
lang="en",
translate=False,
- model="small"
+ model="small",
+ use_vad=False,
)
```
It connects to the server running on localhost at port 9090. With a multilingual model, the transcription language is detected automatically; you can also use the language option to specify a target language, in this case English ("en"). Set the translate option to `True` to translate from the source language to English, or `False` to transcribe in the source language. The `use_vad` option controls whether the server applies voice activity detection to the incoming audio.
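+
+For a quick check, here is a minimal usage sketch (assuming a local file named `audio.wav`; per `TranscriptionClient.__call__`, the client accepts either an `audio` file path or an `hls_url`):
+
+```python
+# Stream a local WAV file to the server for transcription. With the
+# faster_whisper backend, the client writes output.srt when the stream ends.
+client(audio="audio.wav")
+```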
diff --git a/TensorRT_whisper.md b/TensorRT_whisper.md
index 6ff28681..1bc303f0 100644
--- a/TensorRT_whisper.md
+++ b/TensorRT_whisper.md
@@ -21,7 +21,7 @@ docker pull ghcr.io/collabora/whisperbot-base:latest
```bash
docker run -it --gpus all --shm-size=8g \
--ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
- -v /path/to/WhisperLive:/home/WhisperLive \
+ -p 9090:9090 -v /path/to/WhisperLive:/home/WhisperLive \
ghcr.io/collabora/whisperbot-base:latest
```
@@ -48,7 +48,7 @@ bash scripts/build_whisper_tensorrt.sh /root/TensorRT-LLM-examples small
cd /home/WhisperLive
# Install requirements
-bash scripts/setup.sh
+apt update && bash scripts/setup.sh
pip install -r requirements/server.txt
# Required to create mel spectrogram
diff --git a/tests/test_client.py b/tests/test_client.py
index 468b5a13..56d5dbce 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -46,6 +46,7 @@ def test_on_open(self):
"language": self.client.language,
"task": self.client.task,
"model": self.client.model,
+ "use_vad": True
})
self.client.on_open(self.mock_ws_app)
self.mock_ws_app.send.assert_called_with(expected_message)
diff --git a/tests/test_server.py b/tests/test_server.py
index cd14bb31..e5d630fb 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -69,7 +69,7 @@ def test_recv_audio_exception_handling(self, mock_websocket):
class TestServerInferenceAccuracy(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.server_process = subprocess.Popen(["python", "run_server.py"]) # Adjust the command as needed
+ cls.server_process = subprocess.Popen(["python", "run_server.py"])
time.sleep(2)
@classmethod
@@ -134,4 +134,4 @@ def test_unexpected_exception_handling(self, mock_websocket):
for message in log.output:
print(message)
print()
- self.assertTrue(any("Unexpected error: Unexpected error" in message for message in log.output))
+ self.assertTrue(any("Unexpected error" in message for message in log.output))
diff --git a/whisper_live/client.py b/whisper_live/client.py
index 7a8201c6..a759eaea 100644
--- a/whisper_live/client.py
+++ b/whisper_live/client.py
@@ -17,6 +17,7 @@ class Client:
Handles audio recording, streaming, and communication with a server using WebSocket.
"""
INSTANCES = {}
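+    # Sentinel message that tells the server the audio stream has ended.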
+ END_OF_AUDIO = "END_OF_AUDIO"
def __init__(
self,
@@ -25,7 +26,8 @@ def __init__(
lang=None,
translate=False,
model="small",
- srt_file_path="output.srt"
+ srt_file_path="output.srt",
+ use_vad=True
):
"""
Initializes a Client instance for audio recording and streaming to a server.
@@ -55,6 +57,8 @@ def __init__(
self.model = model
self.server_error = False
self.srt_file_path = srt_file_path
+ self.use_vad = use_vad
+        self.last_received_segment = None
if translate:
self.task = "translate"
@@ -120,6 +124,10 @@ def process_segments(self, segments):
(not self.transcript or
float(seg['start']) >= float(self.transcript[-1]['end']))):
self.transcript.append(seg)
+            # Update the last received segment and the last valid response time.
+            if self.last_received_segment is None or self.last_received_segment != segments[-1]["text"]:
+                self.last_response_recieved = time.time()
+                self.last_received_segment = segments[-1]["text"]
# Truncate to last 3 entries for brevity.
text = text[-3:]
@@ -139,7 +147,6 @@ def on_message(self, ws, message):
message (str): The received message from the server.
"""
- self.last_response_recieved = time.time()
message = json.loads(message)
if self.uid != message.get("uid"):
@@ -155,6 +162,7 @@ def on_message(self, ws, message):
self.recording = False
if "message" in message.keys() and message["message"] == "SERVER_READY":
+ self.last_response_recieved = time.time()
self.recording = True
self.server_backend = message["backend"]
print(f"[INFO]: Server Running with backend {self.server_backend}")
@@ -201,6 +209,7 @@ def on_open(self, ws):
"language": self.language,
"task": self.task,
"model": self.model,
+ "use_vad": self.use_vad
}
)
)
@@ -275,7 +284,7 @@ def play_file(self, filename):
assert self.last_response_recieved
while time.time() - self.last_response_recieved < self.disconnect_if_no_response_for:
continue
-
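+        # Signal end of stream so the server can finalize and clean up.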
+ self.send_packet_to_server(Client.END_OF_AUDIO.encode('utf-8'))
if self.server_backend == "faster_whisper":
self.write_srt_file(self.srt_file_path)
self.stream.close()
@@ -497,8 +506,8 @@ class TranscriptionClient:
transcription_client()
```
"""
- def __init__(self, host, port, lang=None, translate=False, model="small"):
- self.client = Client(host, port, lang, translate, model)
+ def __init__(self, host, port, lang=None, translate=False, model="small", use_vad=True):
+ self.client = Client(host, port, lang, translate, model, srt_file_path="output.srt", use_vad=use_vad)
def __call__(self, audio=None, hls_url=None):
"""
diff --git a/whisper_live/server.py b/whisper_live/server.py
index dc7d5c36..7bebe39c 100644
--- a/whisper_live/server.py
+++ b/whisper_live/server.py
@@ -127,6 +127,7 @@ class TranscriptionServer:
def __init__(self):
self.client_manager = ClientManager()
self.no_voice_activity_chunks = 0
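+        # Server-wide default; overridden per connection in handle_new_connection.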
+ self.use_vad = True
def initialize_client(
self, websocket, options, faster_whisper_custom_model_path,
@@ -165,12 +166,11 @@ def initialize_client(
client_uid=options["uid"],
model=options["model"],
initial_prompt=options.get("initial_prompt"),
- vad_parameters=options.get("vad_parameters")
+ vad_parameters=options.get("vad_parameters"),
+ use_vad=self.use_vad,
)
logging.info("Running faster_whisper backend.")
- # self.clients[websocket] = client
- # self.clients_start_time[websocket] = time.time()
self.client_manager.add_client(websocket, client)
def get_audio_from_websocket(self, websocket):
@@ -184,24 +184,54 @@ def get_audio_from_websocket(self, websocket):
-            A numpy array containing the audio.
+            A numpy array containing the audio, or False if the client
+            signalled the end of the audio stream.
"""
frame_data = websocket.recv()
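+        # Client.play_file sends this sentinel when it has finished streaming.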
+ if frame_data == b"END_OF_AUDIO":
+ return False
return np.frombuffer(frame_data, dtype=np.float32)
- def handle_new_connection(self, websocket, backend, faster_whisper_custom_model_path,
+ def handle_new_connection(self, websocket, faster_whisper_custom_model_path,
whisper_tensorrt_path, trt_multilingual):
- logging.info("New client connected")
- options = websocket.recv()
- options = json.loads(options)
+ try:
+ logging.info("New client connected")
+ options = websocket.recv()
+ options = json.loads(options)
+            self.use_vad = options.get('use_vad', True)  # default to VAD enabled if unspecified
+ if self.client_manager.is_server_full(websocket, options):
+ websocket.close()
+ return False # Indicates that the connection should not continue
- if self.client_manager.is_server_full(websocket, options):
- websocket.close()
- return
+ if self.backend == "tensorrt":
+ self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
+ self.initialize_client(websocket, options, faster_whisper_custom_model_path,
+ whisper_tensorrt_path, trt_multilingual)
+ return True
+ except json.JSONDecodeError:
+ logging.error("Failed to decode JSON from client")
+ return False
+ except ConnectionClosed:
+ logging.info("Connection closed by client")
+ return False
+ except Exception as e:
+ logging.error(f"Error during new connection initialization: {str(e)}")
+ return False
- self.backend = backend
- if self.backend == "tensorrt":
- self.vad_detector = VoiceActivityDetector(frame_rate=self.RATE)
+ def process_audio_frames(self, websocket):
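+        """
+        Receive one audio chunk from the websocket and hand it to the client.
+
+        Returns False once the client has signalled the end of the audio
+        stream, True otherwise (including frames suppressed by VAD).
+        """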
+ frame_np = self.get_audio_from_websocket(websocket)
+ client = self.client_manager.get_client(websocket)
+ if frame_np is False:
+ if self.backend == "tensorrt":
+ client.set_eos(True)
+ return False
- self.initialize_client(
- websocket, options, faster_whisper_custom_model_path, whisper_tensorrt_path, trt_multilingual)
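+        # faster_whisper integrates its own VAD filter, so explicit frame-level
+        # VAD is only applied for the tensorrt backend.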
+ if self.backend == "tensorrt":
+ voice_active = self.voice_activity(websocket, frame_np)
+ if voice_active:
+ self.no_voice_activity_chunks = 0
+ client.set_eos(False)
+ if self.use_vad and not voice_active:
+ return True
+
+ client.add_frames(frame_np)
+ return True
def recv_audio(self,
websocket,
@@ -233,33 +263,17 @@ def recv_audio(self,
Raises:
Exception: If there is an error during the audio frame processing.
"""
- try:
- self.handle_new_connection(websocket, backend, faster_whisper_custom_model_path,
- whisper_tensorrt_path, trt_multilingual)
+ self.backend = backend
+ if not self.handle_new_connection(websocket, faster_whisper_custom_model_path,
+ whisper_tensorrt_path, trt_multilingual):
+ return
+ try:
while not self.client_manager.is_client_timeout(websocket):
- try:
- frame_np = self.get_audio_from_websocket(websocket)
- client = self.client_manager.get_client(websocket)
-
- # VAD, for faster_whisper VAD model is already integrated
- if self.backend == "tensorrt":
- if not self.voice_activity(websocket, frame_np):
- continue
- self.no_voice_activity_chunks = 0
- client.set_eos(False)
-
- client.add_frames(frame_np)
-
- except Exception as e:
- logging.error(e)
- self.cleanup(websocket)
- websocket.close()
+ if not self.process_audio_frames(websocket):
break
except ConnectionClosed:
- logging.info("Connection closed by client.")
- except json.JSONDecodeError:
- logging.error("Failed to decode JSON from client")
+ logging.info("Connection closed by client")
except Exception as e:
logging.error(f"Unexpected error: {str(e)}")
finally:
@@ -660,7 +674,7 @@ def speech_to_text(self):
class ServeClientFasterWhisper(ServeClientBase):
def __init__(self, websocket, task="transcribe", device=None, language=None, client_uid=None, model="small.en",
- initial_prompt=None, vad_parameters=None):
+ initial_prompt=None, vad_parameters=None, use_vad=True):
"""
Initialize a ServeClient instance.
The Whisper model is initialized based on the client's language and device availability.
@@ -702,6 +716,7 @@ def __init__(self, websocket, task="transcribe", device=None, language=None, cli
compute_type="int8" if device == "cpu" else "float16",
local_files_only=False,
)
+ self.use_vad = use_vad
# threading
self.trans_thread = threading.Thread(target=self.speech_to_text)
@@ -776,8 +791,8 @@ def transcribe_audio(self, input_sample):
initial_prompt=self.initial_prompt,
language=self.language,
task=self.task,
- vad_filter=True,
- vad_parameters=self.vad_parameters)
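+            # Only forward VAD parameters when the client enabled VAD.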
+ vad_filter=self.use_vad,
+ vad_parameters=self.vad_parameters if self.use_vad else None)
if self.language is None:
self.set_language(info)
return result