diff --git a/.gitignore b/.gitignore
index 06b349b..725405d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,5 +1,4 @@
-start_container.sh
.env*
-test/*
tmp*
+test.log
__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index 10f860f..492dd8c 100644
--- a/README.md
+++ b/README.md
@@ -7,6 +7,7 @@ LinTO-STT can either be used as a standalone transcription service or deployed w
The following families of STT models are currently supported (please refer to respective documentation for more details):
* [Kaldi models](kaldi/README.md)
* [Whisper models](whisper/README.md)
+* [Test scripts](test/README.md)
LinTO-STT can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector.
diff --git a/http_server/ingress.py b/http_server/ingress.py
index ec21d33..424e71c 100644
--- a/http_server/ingress.py
+++ b/http_server/ingress.py
@@ -9,7 +9,7 @@
from flask import Flask, json, request
from serving import GeventServing, GunicornServing
from stt import logger as stt_logger
-from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer
+from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer, warmup
from swagger import setupSwaggerUI
app = Flask("__stt-standalone-worker__")
@@ -24,7 +24,7 @@
logger.setLevel(logging.INFO)
# If websocket streaming route is enabled
-if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]:
+if os.environ.get("ENABLE_STREAMING", "false").lower() in ["true", "1"]:
from flask_sock import Sock
from stt.processing.streaming import ws_streaming
@@ -84,7 +84,9 @@ def transcribe():
logger.error(traceback.format_exc())
logger.error(repr(error))
- return "Server Error: {}".format(str(error)), 400 if isinstance(error, ValueError) else 500
+ return "Server Error: {}".format(str(error)), (
+ 400 if isinstance(error, ValueError) else 500
+ )
@app.errorhandler(405)
@@ -128,12 +130,18 @@ def server_error(error):
serving_type = GunicornServing
logger.debug("Serving with gunicorn")
+ def post_worker_init(worker):
+ logger.info(f"Worker {worker.pid} init")
+ warmup()
+ logger.info(f"Worker {worker.pid} fully initialized")
+
serving = serving_type(
app,
{
"bind": f"0.0.0.0:{args.service_port}",
"workers": args.workers,
"timeout": 3600 * 24,
+ "post_worker_init": post_worker_init,
},
)
logger.info(args)
diff --git a/kaldi/.envdefault b/kaldi/.envdefault
index 33a394c..f22fa40 100644
--- a/kaldi/.envdefault
+++ b/kaldi/.envdefault
@@ -7,8 +7,8 @@ ENABLE_STREAMING=true
# TASK PARAMETERS
SERVICE_NAME=stt
-SERVICES_BROKER=redis://192.168.0.1:6379
-BROKER_PASS=password
+SERVICES_BROKER=redis://172.17.0.1:6379
+BROKER_PASS=
# WEBSOCKET PARAMETERS
STREAMING_PORT=80
diff --git a/kaldi/README.md b/kaldi/README.md
index 7ebfa85..e7c2036 100644
--- a/kaldi/README.md
+++ b/kaldi/README.md
@@ -68,7 +68,7 @@ cp kaldi/.envdefault kaldi/.env
STT can be used three ways:
* Through an [HTTP API](#http-server) using the **http**'s mode.
-* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode.
+* Through a [message broker](#celery-task) using the **task**'s mode.
* Through a [websocket server](#websocket-server) **websocket**'s mode.
Mode is specified using the .env value or environment variable ```SERVING_MODE```.
@@ -99,7 +99,7 @@ This will run a container providing an [HTTP API](#http-api) binded on the host
| LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 |
| MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model |
-### Micro-service within LinTO-Platform stack
+### Celery task
The TASK serving mode connect a celery worker to a message broker.
The SERVICE_MODE value in the .env should be set to ```task```.
@@ -205,7 +205,10 @@ On a successfull transcription the returned object is a json object structured a
* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False)
-## Test
+## Tests
+
+See [Test scripts](../test/README.md) for more details about testing.
+
### Curl
You can test you http API using curl:
```bash
diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md
index 6ce152a..f5bc967 100644
--- a/kaldi/RELEASE.md
+++ b/kaldi/RELEASE.md
@@ -1,3 +1,6 @@
+# 1.0.2
+- Fix task mode for kaldi by updating SERVICES_BROKER and BROKER_PASS in .envdefault
+
# 1.0.1
- Fix streaming mode (websocket) in linto-stt-kaldi
diff --git a/kaldi/docker-entrypoint.sh b/kaldi/docker-entrypoint.sh
index 212b145..74d3b15 100755
--- a/kaldi/docker-entrypoint.sh
+++ b/kaldi/docker-entrypoint.sh
@@ -25,7 +25,7 @@ fi
# Launch parameters, environement variables and dependencies check
if [ -z "$SERVICE_MODE" ]
then
- echo "ERROR: Must specify a serving mode: [ http | task | websocket ]"
+ echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)"
exit -1
else
if [ "$SERVICE_MODE" = "http" ]
@@ -48,7 +48,7 @@ else
echo "Running Websocket server on port ${STREAMING_PORT:=80}"
python websocket/websocketserver.py
else
- echo "ERROR: Wrong serving command: $1"
+ echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)"
exit -1
fi
fi
diff --git a/kaldi/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py
index 9f99406..e0476a1 100644
--- a/kaldi/stt/processing/__init__.py
+++ b/kaldi/stt/processing/__init__.py
@@ -29,5 +29,8 @@
sys.exit(-1)
logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start))
+def warmup():
+ pass
+
# Not implemented yet in Kaldi
USE_GPU = False
diff --git a/test/README.md b/test/README.md
new file mode 100644
index 0000000..1739688
--- /dev/null
+++ b/test/README.md
@@ -0,0 +1,70 @@
+# LinTO-STT-Tests
+
+## Use tests
+
+### HTTP - transcribe
+
+You can test your http server by using:
+
+```bash
+test_deployment.sh
+```
+
+> ⚠️ Be sure to check that you use the right port (default port for testing: 8080).
+
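+Alternatively, here is a minimal Python sketch of the same request (assuming the server runs locally on the default test port 8080 and that you run it from this folder, where bonjour.wav is located):
+
+```python
+import requests  # pip install requests
+
+# POST an audio file to the /transcribe route (same request as in test_deployment.sh)
+with open("bonjour.wav", "rb") as f:
+    response = requests.post(
+        "http://localhost:8080/transcribe",
+        headers={"accept": "application/json"},
+        files={"file": ("bonjour.wav", f, "audio/wav")},
+    )
+print(response.status_code, response.json())
+```
+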
+### HTTP - streaming
+
+You can test your http streaming route by using:
+```bash
+test_streaming.py
+```
+Be sure to have a working microphone (without --audio_file, audio is captured from your microphone).
+> ⚠️ Be sure to check that you use the right port (default port for testing: 8080).
+
+If you want to test the streaming on a file:
+```bash
+test_streaming.py --audio_file bonjour.wav
+```
+
+### Task
+
+You can test your deployment of the task service mode by using:
+
+```bash
+test_celery.py AUDIO.wav
+```
+
+where AUDIO.wav is the file you want to test on; for example, you can use bonjour.wav.
+
+> ⚠️ Be sure to check that you use the same port in your .env and in test_celery.py (default port for testing: 6379)
+
+
+## Unit tests
+
+You will need to install:
+```bash
+pip3 install ddt
+```
+
+To test the Kaldi models, you will need to download the models (see [Kaldi models](../kaldi/README.md)) and then fill the AM_PATH and LM_PATH fields in the [test_config.ini file](test_config.ini).
+> ⚠️ If you don't specify the models, the tests about Kaldi will fail.
+
+To launch the tests, run:
+```bash
+python test/test.py
+```
+
+> ⚠️ Be sure to launch it from the root folder of the repository.
+
+If you want the tests to stop at the first failure, use the -f flag:
+```bash
+python test/test.py -f
+```
+If you want to run a subset of tests, you can use -k with a part of a test name. For example, to run only the Kaldi tests:
+```bash
+python test/test.py -k kaldi
+```
+or, to test with VAD=auditok and DEVICE=cuda:
+```bash
+python test/test.py -k VAD_auditok_DEVICE_cuda
+```
\ No newline at end of file
diff --git a/test/test.py b/test/test.py
new file mode 100644
index 0000000..f683583
--- /dev/null
+++ b/test/test.py
@@ -0,0 +1,318 @@
+import unittest
+import os
+import time
+import subprocess
+import requests
+import re
+from ddt import ddt, idata
+from pathlib import Path
+import warnings
+
+TESTDIR = os.path.dirname(os.path.realpath(__file__))
+ROOTDIR = os.path.dirname(TESTDIR)
+os.chdir(ROOTDIR)
+TESTDIR = os.path.basename(TESTDIR)
+
+
+
+def generate_whisper_test_setups():
+ dockerfiles = [
+ "whisper/Dockerfile.ctranslate2",
+ "whisper/Dockerfile.ctranslate2.cpu",
+ "whisper/Dockerfile.torch",
+ "whisper/Dockerfile.torch.cpu",
+ ]
+
+ servings = ["http", "task"]
+
+ vads = [None, "false", "auditok", "silero"]
+ devices = [None, "cpu", "cuda"]
+ models = ["tiny"]
+
+ for dockerfile in dockerfiles:
+ for device in devices:
+ for vad in vads:
+ for model in models:
+ for serving in servings:
+
+ # Test CPU dockerfile only on CPU
+ if dockerfile.endswith("cpu") and device != "cpu":
+ continue
+
+ # Do not test all VAD settings if not on CPU
+ if vad not in [None, "silero"]:
+ if device != "cpu":
+ continue
+
+ env_variables = ""
+ if vad:
+ env_variables += f"VAD={vad} "
+ if device:
+ env_variables += f"DEVICE={device} "
+ env_variables += f"MODEL={model}"
+
+ yield dockerfile, serving, env_variables
+
+def generate_kaldi_test_setups():
+ dockerfiles = ["kaldi/Dockerfile"]
+
+ servings = ["http", "task"]
+
+ for dockerfile in dockerfiles:
+ for serving in servings:
+ env_variables = ""
+ yield dockerfile, serving, env_variables
+
+def copy_env_file(env_file, env_variables=""):
+ env_variables = env_variables.split()
+ env_variables.append("SERVICE_MODE=")
+ with open(env_file, "r") as f:
+ lines = f.readlines()
+ with open(f"{TESTDIR}/.env", "w") as f:
+ for line in lines:
+ if not any([line.startswith(b.split("=")[0] + "=") for b in env_variables]):
+ f.write(line)
+
+@ddt
+class TestRunner(unittest.TestCase):
+
+ built_images = []
+ redis_launched = False
+
+ # def __init__(self, *args, **kwargs):
+ # super(TestRunner, self).__init__(*args, **kwargs)
+ # self.cleanup()
+
+ def echo_success(self, message):
+ print('\033[0;32m' + u'\u2714' + '\033[0m ' + message)
+
+ def echo_failure(self, message):
+ print('\033[0;31m' + u'\u2716' + '\033[0m ' + message)
+
+ def echo_note(self, message):
+ print(u'\u231B' + ' ' + message)
+
+ def echo_command(self, message):
+ print(f"$ {message}")
+
+ def report_failure(self, message, expect_failure=False):
+ if not expect_failure:
+ self.echo_failure(message)
+ self.cleanup()
+ if not expect_failure:
+ self.fail(message)
+ return message
+
+ def report_success(self):
+ self.echo_success("Test passed.")
+ self.cleanup()
+
+ def cleanup(self):
+ # Check if the container is running
+ p = subprocess.Popen(["docker", "ps", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, err = p.communicate()
+ if b"test_container" in out:
+ self.echo_command("docker stop test_container")
+ subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ time.sleep(0.2) # Without this, the following tests can fail (The container name "/test_container" is already in use)
+
+ def process_output(self, p):
+ l = p.communicate()[0].decode('utf-8').replace('\n', '\n\t')
+ e = p.communicate()[1].decode('utf-8').replace('\n', '\n\t')
+ return f" \u2192 Log Message:\n\t{l}\n \u2192 Error Message:\n\t{e}"
+
+
+ def check_http_server_availability(self, server, pid):
+        total_wait_time = SERVER_STARTING_TIMEOUT  # maximum wait time in seconds
+ retry_interval = 1 # Interval between attempts (in seconds)
+ elapsed_time = 0
+
+ while elapsed_time < total_wait_time:
+ try:
+ response = requests.head(server)
+ if response.status_code == 200:
+ self.echo_note(f"Server: {server} is available after {elapsed_time} sec.")
+ return
+ except requests.ConnectionError:
+ pass
+ if pid.poll() is not None:
+ return f"The server container has stopped for an unexpected reason.\n{self.process_output(pid)}"
+
+ time.sleep(retry_interval)
+ elapsed_time += retry_interval
+
+ return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed.\n{self.process_output(pid)}"
+
+ def launch_redis(self):
+ if TestRunner.redis_launched:
+ return
+ cmd = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug"
+ self.echo_command(cmd)
+ p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ time.sleep(2)
+ if p.poll() is not None:
+ self.cleanup()
+ return f"Redis server failed to start.\n{self.process_output(p)}", None
+ TestRunner.redis_launched = True
+
+ def build_and_run_container(self, serving, docker_image, env_variables, use_local_cache):
+ self.echo_note(f"* Docker image: {docker_image}")
+ self.echo_note(f"* Options.....: {env_variables}")
+ build_args = ""
+ for i, env in enumerate(env_variables.split()):
+ if i>0 and env_variables.split()[i-1] =="-v":
+ build_args += f"-v {env} "
+ elif env=="-v":
+ continue
+ else:
+ build_args += f"--env {env} "
+ build_args += f"--env SERVICE_MODE={serving} "
+ if use_local_cache:
+ home = str(Path.home())
+ build_args += f"-v {home}/.cache:/root/.cache "
+
+ if serving == "task":
+ self.launch_redis()
+ build_args += "-v {}/:/opt/audio ".format(os.getcwd())
+
+ tag = f"test_{os.path.basename(docker_image)}"
+ if tag not in TestRunner.built_images:
+ # Only build images that have not been built yet
+ cmd = f'docker build . -f {docker_image} -t linto-stt-test:{tag}'
+ self.echo_command(cmd)
+ start_time = time.time()
+ p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ p.wait()
+ end_time = time.time()
+ if p.poll() != 0:
+ self.cleanup()
+ return f"Docker build failed.\n{self.process_output(p)}", None
+ self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.")
+ TestRunner.built_images.append(tag)
+
+ cmd=f"docker run --rm -p 8080:80 --name test_container --env-file {TESTDIR}/.env --gpus all {build_args} linto-stt-test:{tag}"
+ self.echo_command(cmd)
+ p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ if p.poll() is not None:
+ return f"Docker container failed to start.\n{self.process_output(p)}", None
+ return None, p
+
+ def transcribe(self, command, regex, test_file, error_message, success_message, timeout=None):
+ start = time.time()
+ res = subprocess.run(command, shell=True, timeout=timeout, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ end = time.time()
+ if res.returncode != 0:
+ raise FileNotFoundError(f"Error: {res.stderr.decode('utf-8')}")
+ res = res.stdout.decode('utf-8')
+ if not re.search(regex, res):
+ message = f"{error_message}: The string '{res}' is not matching the regex ({regex}), the server didn't transcribe correctly."
+ return self.report_failure(message)
+ self.echo_note(f"{success_message} has transcribed {test_file} in {end - start:.0f} sec.")
+ return
+
+ def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http", env_variables="", test_file=f"{TESTDIR}/bonjour.wav", use_local_cache=True, expect_failure=False):
+ warnings.simplefilter("ignore", ResourceWarning)
+ regex = ""
+ if os.path.basename(test_file) == "bonjour.wav":
+ regex = re.compile("[bB]onjour")
+ r, pid = self.build_and_run_container(serving, docker_image, env_variables, use_local_cache)
+ if r:
+ return self.report_failure(r, expect_failure=expect_failure)
+ if serving == "http":
+ r=self.check_http_server_availability("http://localhost:8080/healthcheck", pid)
+ if r:
+ return self.report_failure(r, expect_failure=expect_failure)
+ cmd = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"'
+ self.echo_command(cmd)
+ r = self.transcribe(cmd, regex, test_file, "Error transcription", "HTTP route 'transcribe'")
+ if r:
+ return self.report_failure(r, expect_failure=expect_failure)
+ cmd = f"python3 {TESTDIR}/test_streaming.py --audio_file {test_file}"
+ self.echo_command(cmd)
+ r = self.transcribe(cmd, regex, test_file, "Error streaming", "HTTP route 'streaming'")
+ elif serving == "task":
+            # You can get stuck here if the server crashed, because the task will stay in the queue forever
+ cmd = f"python3 {TESTDIR}/test_celery.py {test_file}"
+ self.echo_command(cmd)
+ r = self.transcribe(cmd, regex, test_file, "Error task", "TASK route", timeout=60)
+ else:
+ raise RuntimeError(f"Unknown serving mode: {serving}")
+ if r:
+ return self.report_failure(r, expect_failure=expect_failure)
+ if not expect_failure:
+ self.report_success()
+ return ""
+
+ def setUp(self):
+ # Print an empty line because unittest prints the name of the test first, without a newline
+ print()
+ print("-"*70)
+
+ def tearDown(self):
+ print("-"*70)
+
+ @idata(generate_kaldi_test_setups())
+ def test_01_kaldi_integration(self, setup):
+ dockerfile, serving, env_variables = setup
+ if AM_PATH is None or LM_PATH is None or AM_PATH=="" or LM_PATH=="":
+            self.fail("AM or LM path not provided. Cannot run the Kaldi tests.")
+ if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH):
+ self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}")
+ copy_env_file("kaldi/.envdefault")
+ env_variables += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM"
+ self.run_test(dockerfile, serving=serving, env_variables=env_variables)
+
+
+ @idata(generate_whisper_test_setups())
+ def test_03_whisper_integration(self, setup):
+ dockerfile, serving, env_variables = setup
+ copy_env_file("whisper/.envdefault", env_variables)
+ self.run_test(dockerfile, serving=serving, env_variables=env_variables)
+
+ def test_02_whisper_failures_cuda_on_cpu_dockerfile(self):
+ env_variables = "MODEL=tiny DEVICE=cuda"
+ dockerfile = "whisper/Dockerfile.ctranslate2.cpu"
+ copy_env_file("whisper/.envdefault", env_variables)
+ self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True))
+
+ def test_02_whisper_failures_not_existing_file(self):
+ env_variables = "MODEL=tiny"
+ copy_env_file("whisper/.envdefault", env_variables)
+ with self.assertRaises(FileNotFoundError):
+ self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=True)
+ self.cleanup()
+
+ def test_02_whisper_failures_wrong_vad(self):
+ env_variables = "VAD=whatever MODEL=tiny"
+ copy_env_file("whisper/.envdefault", env_variables)
+ self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=True))
+
+ def test_04_model_whisper(self):
+ env_variables = "MODEL=small"
+ copy_env_file("whisper/.envdefault", env_variables)
+ self.run_test(env_variables=env_variables)
+
+def finalize_tests():
+ subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+
+
+AM_PATH = None
+LM_PATH = None
+SERVER_STARTING_TIMEOUT = 60
+
+if __name__ == '__main__':
+ from configparser import ConfigParser
+ config = ConfigParser()
+
+ config.read(f"{TESTDIR}/test_config.ini")
+
+ SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else SERVER_STARTING_TIMEOUT
+
+ AM_PATH = config.get('kaldi', 'AM_PATH')
+ LM_PATH = config.get('kaldi', 'LM_PATH')
+
+ try:
+ unittest.main(verbosity=2)
+ finally:
+ finalize_tests()
diff --git a/test/test_celery.py b/test/test_celery.py
new file mode 100755
index 0000000..59ed62e
--- /dev/null
+++ b/test/test_celery.py
@@ -0,0 +1,16 @@
+import sys
+from celery import Celery
+
+def transcribe_task(file_path):
+ celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1')
+ r = celery.send_task(
+ 'transcribe_task',
+ (
+ file_path,
+ True,
+ ),
+ queue='stt')
+ return r.get()
+
+if __name__ == '__main__':
+ print(transcribe_task(sys.argv[1]))
\ No newline at end of file
diff --git a/test/test_config.ini b/test/test_config.ini
new file mode 100644
index 0000000..76bf72d
--- /dev/null
+++ b/test/test_config.ini
@@ -0,0 +1,6 @@
+[server]
+STARTING_TIMEOUT=60
+
+[kaldi]
+AM_PATH=
+LM_PATH=
\ No newline at end of file
diff --git a/test/test_deployment.sh b/test/test_deployment.sh
index b1b8d36..84daac8 100755
--- a/test/test_deployment.sh
+++ b/test/test_deployment.sh
@@ -1 +1 @@
-curl -X POST "http://localhost:8888/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@bonjour.wav;type=audio/wav"
+curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@bonjour.wav;type=audio/wav"
diff --git a/test/test_streaming.py b/test/test_streaming.py
new file mode 100644
index 0000000..d78f6cf
--- /dev/null
+++ b/test/test_streaming.py
@@ -0,0 +1,120 @@
+import asyncio
+import websockets
+import json
+import shutil
+
+def linstt_streaming(*kargs, **kwargs):
+ text = asyncio.run(_linstt_streaming(*kargs, **kwargs))
+ return text
+
+async def _linstt_streaming(
+ audio_file,
+ ws_api = "ws://localhost:8080/streaming",
+ verbose = False,
+):
+
+ if audio_file is None:
+ import pyaudio
+ # Init pyaudio
+ audio = pyaudio.PyAudio()
+ stream = audio.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=2048)
+ if verbose > 1:
+ print("Start recording")
+ else:
+ stream = open(audio_file, "rb")
+
+ alive = True
+ text = ""
+ partial = None
+
+ try:
+ async with websockets.connect(ws_api) as websocket:
+ await websocket.send(json.dumps({"config" : {"sample_rate": 16000 }}))
+ while alive:
+ try:
+ data = stream.read(32000)
+ if audio_file and not data:
+ if verbose > 1:
+ print("\nAudio file finished")
+ alive = False
+ await websocket.send(data)
+ res = await websocket.recv()
+ message = json.loads(res)
+ if message is None:
+ if verbose > 1:
+ print("\n Received None")
+ continue
+ if "partial" in message.keys():
+ partial = message["partial"]
+ if verbose:
+ print_partial(partial)
+ elif "text" in message.keys():
+ line = message["text"]
+ if verbose:
+ print_final(line)
+ if line:
+ if text:
+ text += "\n"
+ text += line
+ elif verbose:
+ print("???", message)
+ except KeyboardInterrupt:
+ if verbose > 1:
+ print("\nKeyboard interrupt")
+ alive = False
+ await websocket.send(json.dumps({"eof" : 1}))
+ res = await websocket.recv()
+ message = json.loads(res)
+ if isinstance(message, str):
+ message = json.loads(message)
+ if text:
+ text += " "
+ text += message["text"]
+ try:
+ res = await websocket.recv()
+ except websockets.ConnectionClosedOK:
+ if verbose > 1:
+ print("Websocket Closed")
+ except KeyboardInterrupt:
+ if verbose > 1:
+ print("\nKeyboard interrupt")
+ if verbose:
+ print_final("= FULL TRANSCRIPTION ", background="=")
+ print(text)
+
+ return text
+
+def print_partial(text):
+ text = text + "…"
+ terminal_size = shutil.get_terminal_size()
+ width = terminal_size.columns
+ start = ((len(text) - 1)// width) * width
+ if start > 0:
+ print(" "*width, end="\r")
+ if start < len(text) - 1:
+ print("…"+text[start+1:]+" "*(width-len(text)-start-1), end="\r")
+ else:
+ print(text[-width:], end="\r")
+ else:
+ print(text, end="\r")
+
+def print_final(text, background=" "):
+ terminal_size = shutil.get_terminal_size()
+ width = terminal_size.columns
+ print(background * width, end="\r")
+ print(text)
+
+if __name__ == "__main__":
+
+ import argparse
+ parser = argparse.ArgumentParser(description='Transcribe input streaming (from mic or a file) with LinSTT',
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument('--server', help='Transcription server',
+ default="ws://localhost:8080/streaming",
+ )
+ parser.add_argument("-v", "--verbose", action="store_true", help="Verbose mode")
+ parser.add_argument("--audio_file", default=None, help="A path to an audio file to transcribe (if not provided, use mic)")
+ args = parser.parse_args()
+
+ res = linstt_streaming(args.audio_file, args.server, verbose=2 if args.verbose else 1)
\ No newline at end of file
diff --git a/whisper/.envdefault b/whisper/.envdefault
index 75919c0..a8f8794 100644
--- a/whisper/.envdefault
+++ b/whisper/.envdefault
@@ -1,7 +1,7 @@
############################################
# SERVING PARAMETERS
############################################
-# "http" or "task"
+# "http" or "task" or "websocket"
SERVICE_MODE=http
# Below: used when SERVICE_MODE=task
@@ -9,6 +9,12 @@ SERVICE_NAME=stt
SERVICES_BROKER=redis://172.17.0.1:6379
BROKER_PASS=
+# HTTP PARAMETERS
+ENABLE_STREAMING=true
+
+# WEBSOCKET PARAMETERS
+STREAMING_PORT=80
+
############################################
# STT MODELING PARAMETERS
############################################
@@ -30,6 +36,15 @@ PROMPT=
# This option is experimental (and not implemented with ctranslate2).
# ALIGNMENT_MODEL=wav2vec
+# Voice Activity Detection (VAD) method
+# It can be either "0"/"false" (no VAD), "silero", or "1"/"true"/"auditok" (by default)
+# VAD=auditok
+
+# Voice Activity Detection (VAD) parameters
+# VAD_DILATATION=0.1
+# VAD_MIN_SPEECH_DURATION=0.1
+# VAD_MIN_SILENCE_DURATION=0.1
+
############################################
# EFFICIENCY PARAMETERS
############################################
@@ -40,7 +55,7 @@ PROMPT=
# CUDA_VISIBLE_DEVICES=0
# Number of threads per worker when running on CPU
-OMP_NUM_THREADS=4
+NUM_THREADS=4
-# Number of workers
+# Number of workers minus one (all except the main one)
CONCURRENCY=2
diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2
index c2b3cd5..5fd3c53 100644
--- a/whisper/Dockerfile.ctranslate2
+++ b/whisper/Dockerfile.ctranslate2
@@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket
COPY document /usr/src/app/document
COPY whisper/stt /usr/src/app/stt
COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+COPY test/bonjour.wav /usr/src/app/test/bonjour.wav
ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu
index df5eac7..1f0f40c 100644
--- a/whisper/Dockerfile.ctranslate2.cpu
+++ b/whisper/Dockerfile.ctranslate2.cpu
@@ -6,7 +6,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
# Install python dependencies
COPY whisper/requirements.ctranslate2.txt ./
RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt
-
WORKDIR /usr/src/app
COPY celery_app /usr/src/app/celery_app
@@ -15,6 +14,7 @@ COPY websocket /usr/src/app/websocket
COPY document /usr/src/app/document
COPY whisper/stt /usr/src/app/stt
COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+COPY test/bonjour.wav /usr/src/app/test/bonjour.wav
ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
diff --git a/whisper/Dockerfile.torch b/whisper/Dockerfile.torch
index 06b22f3..acd0b08 100644
--- a/whisper/Dockerfile.torch
+++ b/whisper/Dockerfile.torch
@@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket
COPY document /usr/src/app/document
COPY whisper/stt /usr/src/app/stt
COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+COPY test/bonjour.wav /usr/src/app/test/bonjour.wav
ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu
index 17a3fb8..2d45336 100644
--- a/whisper/Dockerfile.torch.cpu
+++ b/whisper/Dockerfile.torch.cpu
@@ -12,7 +12,6 @@ RUN pip3 install \
# Install python dependencies
COPY whisper/requirements.torch.txt ./
RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt
-
WORKDIR /usr/src/app
COPY celery_app /usr/src/app/celery_app
@@ -21,6 +20,7 @@ COPY websocket /usr/src/app/websocket
COPY document /usr/src/app/document
COPY whisper/stt /usr/src/app/stt
COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./
+COPY test/bonjour.wav /usr/src/app/test/bonjour.wav
ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt"
diff --git a/whisper/README.md b/whisper/README.md
index 41dc46a..52d2122 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -114,17 +114,22 @@ cp whisper/.envdefault whisper/.env
| PARAMETER | DESCRIPTION | EXEMPLE |
|---|---|---|
-| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` |
-| MODEL | Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... |
-| LANGUAGE | (Optional) Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... |
-| PROMPT | (Optional) Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` |
-| ALIGNMENT_MODEL | (Optional and deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... |
-| DEVICE | (Optional) Device to use for the model | `cpu` \| `cuda` ... |
-| CUDA_VISIBLE_DEVICES | (Optional) GPU device index to use, if several. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... |
-| CONCURRENCY | Maximum number of parallel requests | `2` |
-| SERVICE_NAME | (For the task mode) queue's name for task processing | `my-stt` |
-| SERVICE_BROKER | (For the task mode) URL of the message broker | `redis://my-broker:6379` |
+| SERVICE_MODE | (Required) STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` |
+| MODEL | (Required) Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... |
+| LANGUAGE | Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... |
+| PROMPT | Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` |
+| DEVICE | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` |
+| NUM_THREADS | Maximum number of threads to use when running on CPU | `1` \| `4` \| ... |
+| CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... |
+| CONCURRENCY | Maximum number of parallel requests (number of workers minus one) | `2` |
+| VAD | Voice Activity Detection method. Use "false" to disable. If not specified, the default is auditok VAD. | `true` \| `false` \| `1` \| `0` \| `auditok` \| `silero` |
+| ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | `true` \| `false` |
+| STREAMING_PORT | (For the websocket mode) the listening port for incoming websocket connections | `80` |
+| SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` |
+| SERVICE_BROKER | (For the task mode only) URL of the message broker | `redis://my-broker:6379` |
| BROKER_PASS | (For the task mode only) broker password | `my-password` \| (empty) |
+| ALIGNMENT_MODEL | (Deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... |
+
#### MODEL environment variable
@@ -184,7 +189,7 @@ and also `yue(cantonese)` since large-v3.
STT can be used in two ways:
* Through an [HTTP API](#http-server) using the **http**'s mode.
-* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode.
+* Through a [message broker](#celery-task) using the **task**'s mode.
Mode is specified using the .env value or environment variable ```SERVING_MODE```.
```bash
@@ -217,11 +222,11 @@ You may also want to add specific options:
| Variables | Description | Example |
|:-|:-|:-|
| `HOST_SERVING_PORT` | Host serving port | 8080 |
-| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
| `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
-| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
-### Micro-service within LinTO-Platform stack
+### Celery task
The TASK serving mode connect a celery worker to a message broker.
The SERVICE_MODE value in the .env should be set to ```task```.
@@ -248,10 +253,16 @@ You may also want to add specific options:
| Variables | Description | Example |
|:-|:-|:-|
| `` | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model |
-| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
| `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
-| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+
+### Websocket Server
+The websocket serving mode deploys a streaming transcription service only.
+The SERVICE_MODE value in the .env should be set to ```websocket```.
+
+Usage is the same as the [http streaming API](#streaming).
## Usages
### HTTP API
@@ -286,6 +297,18 @@ Return the transcripted text using "text/plain" or a json object when using "app
}
```
+#### /streaming
+The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true.
+
+The route accepts websocket connections. Exchanges are structured as follows:
+1. The client sends a json {"config": {"sample_rate":16000}}.
+2. The client sends an audio chunk (go to 3) or {"eof" : 1} (go to 5).
+3. The server sends either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}.
+4. Back to 2.
+5. The server sends a final result and closes the connection.
+
+> The connection will be closed and the worker will be freed if no chunk is received for 10s.
+
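+Below is a minimal Python client sketch of this exchange (assuming a local server on port 8080 and a 16kHz mono 16-bit PCM WAV file; see test/test_streaming.py for a more complete example):
+
+```python
+import asyncio
+import json
+import websockets  # pip install websockets
+
+async def transcribe(path, url="ws://localhost:8080/streaming"):
+    async with websockets.connect(url) as ws:
+        # 1. Send the configuration
+        await ws.send(json.dumps({"config": {"sample_rate": 16000}}))
+        with open(path, "rb") as f:
+            while chunk := f.read(32000):
+                # 2./3. Send an audio chunk and read a partial or final result
+                await ws.send(chunk)
+                print(json.loads(await ws.recv()))
+        # 5. Signal end of stream and read the last final result
+        await ws.send(json.dumps({"eof": 1}))
+        print(json.loads(await ws.recv()))
+
+asyncio.run(transcribe("test/bonjour.wav"))
+```
+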
#### /docs
The /docs route offers a OpenAPI/swagger interface.
@@ -320,13 +343,24 @@ On a successfull transcription the returned object is a json object structured a
* The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False)
-## Test
+## Tests
+
+See [Test scripts](../test/README.md) for more details about testing.
+
### Curl
-You can test you http API using curl:
+You can test your http API using curl:
+
```bash
curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav"
```
+### Streaming
+You can test your streaming API using a websocket:
+
+```bash
+python test/test_streaming.py --server ws://YOUR_SERVICE:YOUR_PORT/streaming --audio_file test/bonjour.wav
+```
+
## License
This project is developped under the AGPLv3 License (see LICENSE).
@@ -339,3 +373,4 @@ This project is developped under the AGPLv3 License (see LICENSE).
* [HuggingFace Transformers](https://github.com/huggingface/transformers)
* [SpeechBrain](https://github.com/speechbrain/speechbrain)
* [TorchAudio](https://github.com/pytorch/audio)
+* [Whisper_Streaming](https://github.com/ufal/whisper_streaming)
\ No newline at end of file
diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md
index f54537e..84de80f 100644
--- a/whisper/RELEASE.md
+++ b/whisper/RELEASE.md
@@ -1,3 +1,10 @@
+# 1.0.3
+- Make Voice Activity Detection (VAD) configurable
+- Change default VAD from silero (neural approach) to auditok (heuristic approach), because silero can have unpredictable behaviour in some corner cases
+- Streaming support
+- New NUM_THREADS env variable to control the number of threads
+- Load the model when launching the service (not at the first request)
+
# 1.0.2
- ct2/faster_whisper: Upgrade faster_whisper and support recent distilled models
- ct2/faster_whisper: Fix possible gluing of different words together
diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh
index 97a3804..09ea120 100755
--- a/whisper/docker-entrypoint.sh
+++ b/whisper/docker-entrypoint.sh
@@ -14,7 +14,7 @@ fi
# Launch parameters, environement variables and dependencies check
if [ -z "$SERVICE_MODE" ]
then
- echo "ERROR: Must specify a serving mode: [ http | task | websocket ]"
+ echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)"
exit -1
else
if [ "$SERVICE_MODE" = "http" ]
@@ -41,9 +41,12 @@ else
/usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit 1
echo "RUNNING STT CELERY WORKER"
celery --app=celery_app.celeryapp worker $OPT -Ofair --queues=${SERVICE_NAME} -c ${CONCURRENCY} -n ${SERVICE_NAME}_worker@%h
-
+ elif [ "$SERVICE_MODE" == "websocket" ]
+ then
+ echo "Running Websocket server on port ${STREAMING_PORT:=80}"
+ python3 websocket/websocketserver.py
else
- echo "ERROR: Wrong serving command: $SERVICE_MODE"
+ echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)"
exit -1
fi
fi
diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt
index e471fd7..87b5e80 100644
--- a/whisper/requirements.ctranslate2.txt
+++ b/whisper/requirements.ctranslate2.txt
@@ -11,6 +11,7 @@ regex
requests>=2.26.0
wavio>=0.0.4
websockets
+auditok
#faster_whisper==1.0.1
# This is version faster_whisper==1.0.1 + option for (persistent) prompt + fix for large-v3
git+https://github.com/linto-ai/faster-whisper.git
\ No newline at end of file
diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt
index 3976414..e5f5f93 100644
--- a/whisper/requirements.torch.txt
+++ b/whisper/requirements.torch.txt
@@ -15,4 +15,5 @@ wavio>=0.0.4
websockets
whisper-timestamped
onnxruntime
-torchaudio
\ No newline at end of file
+torchaudio
+auditok
\ No newline at end of file
diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py
index f5551af..8bac458 100644
--- a/whisper/stt/__init__.py
+++ b/whisper/stt/__init__.py
@@ -12,6 +12,20 @@
# see https://github.com/guillaumekln/faster-whisper/issues/150
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order
+if os.environ.get("VAD","auditok").lower() in ["true", "1"]:
+ VAD = "auditok"
+elif os.environ.get("VAD","auditok").lower() in ["false", "0"]:
+ VAD = False
+else:
+ VAD = os.environ.get("VAD","auditok")
+
+VAD_DILATATION = float(os.environ.get("VAD_DILATATION", 0.5))
+VAD_MIN_SPEECH_DURATION = float(os.environ.get("VAD_MIN_SPEECH_DURATION", 0.1))
+VAD_MIN_SILENCE_DURATION = float(os.environ.get("VAD_MIN_SILENCE_DURATION", 0.1))
+
+NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS"))
+if NUM_THREADS is not None:
+    NUM_THREADS = int(NUM_THREADS)
+
try:
import faster_whisper
@@ -36,3 +50,21 @@
USE_TORCHAUDIO = True
except ImportError:
USE_TORCHAUDIO = False
+
+if USE_CTRANSLATE2:
+ def set_num_threads(n):
+ # os.environ["OMP_NUM_THREADS"] = str(n)
+ pass
+else:
+ import torch
+ DEFAULT_NUM_THREADS = torch.get_num_threads()
+ def set_num_threads(n):
+ torch.set_num_threads(n)
+
+# Number of CPU threads
+if NUM_THREADS is None:
+ NUM_THREADS = DEFAULT_NUM_THREADS
+if NUM_THREADS is not None:
+ NUM_THREADS = int(NUM_THREADS)
+# For Torch, we will set it afterward, because setting that before loading the model can hang the process (see https://github.com/pytorch/pytorch/issues/58962)
+set_num_threads(1)
diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py
index b0e7f6d..b82a482 100644
--- a/whisper/stt/processing/__init__.py
+++ b/whisper/stt/processing/__init__.py
@@ -2,7 +2,7 @@
import os
from lockfile import FileLock
-from stt import USE_CTRANSLATE2, logger
+from stt import USE_CTRANSLATE2, VAD, logger, set_num_threads, NUM_THREADS
from .alignment_model import get_alignment_model, load_alignment_model
from .decoding import decode
@@ -18,12 +18,19 @@
"USE_GPU",
]
-
+def warmup():
+ model.check_loaded()
+ audio_data = load_audiofile("test/bonjour.wav")
+ transcription = decode(audio_data, MODEL, False)
+ logger.info(f"Warmup result: {transcription}")
+
class LazyLoadedModel:
- def __init__(self, model_type, device):
+ def __init__(self, model_type, device, num_threads):
self.model_type = model_type
self.device = device
+ self.num_threads = num_threads
self._model = None
+ self.has_set_num_threads = False
def check_loaded(self):
if self._model is None:
@@ -31,12 +38,19 @@ def check_loaded(self):
with FileLock(lockfile):
self._model = load_whisper_model(self.model_type, device=self.device)
+ def check_num_threads(self):
+ if not self.has_set_num_threads and self.num_threads:
+ set_num_threads(self.num_threads)
+ self.has_set_num_threads = True
+
def __getattr__(self, name):
self.check_loaded()
+ self.check_num_threads()
return getattr(self._model, name)
def __call__(self, *args, **kwargs):
self.check_loaded()
+ self.check_num_threads()
return self._model(*args, **kwargs)
@@ -51,16 +65,8 @@ def __call__(self, *args, **kwargs):
language = get_language()
logger.info(f"Using language {language}")
-# Load ASR model
-model_type = os.environ.get("MODEL", "medium")
-logger.info(
- f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..."
-)
-try:
- model = LazyLoadedModel(model_type, device=device)
- # model = load_whisper_model(model_type, device=device)
-except Exception as err:
- raise Exception("Failed to load transcription model: {}".format(str(err))) from err
+logger.info(f"VAD={VAD}")
+logger.info(f"USE_CTRANSLATE2={USE_CTRANSLATE2}")
# Load alignment model (if any)
alignment_model = get_alignment_model(os.environ.get("alignment_model"), language)
@@ -77,4 +83,19 @@ def __call__(self, *args, **kwargs):
)
alignment_model = {} # Alignement model(s) will be loaded on the fly
-MODEL = (model, alignment_model)
+
+# Load ASR model
+model_type = os.environ.get("MODEL", "medium")
+logger.info(
+ f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..."
+)
+try:
+ model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS)
+ MODEL = (model, alignment_model)
+ if USE_GPU:
+ warmup()
+except Exception as err:
+ raise Exception("Failed to load transcription model: {}".format(str(err))) from err
+
+
+
\ No newline at end of file
diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py
index 5f032e4..9ead692 100644
--- a/whisper/stt/processing/decoding.py
+++ b/whisper/stt/processing/decoding.py
@@ -5,8 +5,9 @@
from typing import Tuple, Union
import numpy as np
-from stt import USE_CTRANSLATE2, logger
+from stt import USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SILENCE_DURATION, VAD_MIN_SPEECH_DURATION, logger
+from .vad import remove_non_speech
from .alignment_model import get_alignment_model, load_alignment_model
from .text_normalize import normalize_text, remove_emoji, remove_punctuation
from .utils import SAMPLE_RATE, get_language
@@ -17,7 +18,6 @@
import whisper_timestamped
USE_ACCURATE = True
-USE_VAD = True
if USE_ACCURATE:
default_beam_size = 5
@@ -47,7 +47,6 @@ def decode(
) -> dict:
if language is None:
language = get_language()
-
kwargs = copy.copy(locals())
kwargs.pop("model_and_alignementmodel")
kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel
@@ -64,7 +63,6 @@ def decode(
kwargs.pop("alignment_model")
res = decode_ct2(**kwargs)
else:
- print("OK")
res = decode_torch(**kwargs)
logger.info("Transcription complete (t={}s)".format(time.time() - start_t))
@@ -73,24 +71,30 @@ def decode(
def decode_ct2(
- audio, model, with_word_timestamps, language, remove_punctuation_from_words, **kwargs
+ audio,
+ model,
+ with_word_timestamps,
+ language,
+ remove_punctuation_from_words,
+ **kwargs,
):
kwargs["no_speech_threshold"] = 1 # To avoid empty output
if kwargs.get("beam_size") is None:
kwargs["beam_size"] = 1
if kwargs.get("best_of") is None:
kwargs["best_of"] = 1
-
+ if VAD:
+ _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \
+ min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, return_format="dict")
segments, info = model.transcribe(
audio,
word_timestamps=with_word_timestamps,
language=language,
# Careful with the following options
max_initial_timestamp=10000.0,
- vad_filter=USE_VAD,
+ vad_filter=speech_segments if VAD else False,
**kwargs,
)
-
segments = list(segments)
return format_faster_whisper_response(
@@ -118,6 +122,10 @@ def decode_torch(
fp16 = model.device != torch.device("cpu")
+ if VAD:
+ _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \
+ min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION,)
+
kwargs = dict(
language=language,
fp16=fp16,
@@ -127,13 +135,15 @@ def decode_torch(
condition_on_previous_text=condition_on_previous_text,
no_speech_threshold=no_speech_threshold,
compression_ratio_threshold=compression_ratio_threshold,
- vad=USE_VAD,
+ vad=speech_segments if VAD else False,
initial_prompt=prompt,
)
if alignment_model is None:
# Use Whisper cross-attention weights
- whisper_res = whisper_timestamped.transcribe(model, audio, verbose=None, **kwargs)
+ whisper_res = whisper_timestamped.transcribe(
+ model, audio, verbose=None, **kwargs
+ )
if language is None:
language = whisper_res["language"]
logger.info(f"Detected language: {language}")
@@ -175,7 +185,9 @@ def decode_torch(
result["text"] = text
result["language"] = language
result["confidence-score"] = (
- np.exp(np.array([r["avg_logprob"] for r in segments])).mean() if len(segments) else 0.0
+ np.exp(np.array([r["avg_logprob"] for r in segments])).mean()
+ if len(segments)
+ else 0.0
)
if not with_word_timestamps:
@@ -251,7 +263,9 @@ def decode_torch(
return result
-def format_whisper_timestamped_response(transcription, remove_punctuation_from_words=False):
+def format_whisper_timestamped_response(
+ transcription, remove_punctuation_from_words=False
+):
"""Format Whisper response."""
for i, seg in enumerate(transcription["segments"][:-1]):
@@ -281,9 +295,11 @@ def format_whisper_timestamped_response(transcription, remove_punctuation_from_w
return {
"text": transcription["text"].strip(),
"language": transcription["language"],
- "confidence-score": round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2)
- if len(segments)
- else 0.0,
+ "confidence-score": (
+ round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2)
+ if len(segments)
+ else 0.0
+ ),
"words": words,
}
@@ -307,7 +323,10 @@ def checked_timestamps(start, end=None):
if end == start:
pass # end = start + 0.01
else:
- print("WARNING, end timestamp %f is smaller than start timestamp %f" % (end, start))
+ print(
+ "WARNING, end timestamp %f is smaller than start timestamp %f"
+ % (end, start)
+ )
if end is None:
return start
return (start, end)
@@ -327,7 +346,11 @@ def checked_timestamps(start, end=None):
and len(words)
and len(word_strip) > 1
and word_strip[0] in glue_punctuations
- and (word_strip == word_string or not contains_alphanum(words[-1]["text"]) or not contains_alphanum(word_strip))
+ and (
+ word_strip == word_string
+ or not contains_alphanum(words[-1]["text"])
+ or not contains_alphanum(word_strip)
+ )
):
words[-1]["text"] += word_strip
words[-1]["confidence"].append(word.probability)
@@ -368,5 +391,6 @@ def checked_timestamps(start, end=None):
transcription, remove_punctuation_from_words=remove_punctuation_from_words
)
+
def contains_alphanum(text: str) -> bool:
- return re.search(r"[^\W\'\-_]", text)
\ No newline at end of file
+ return re.search(r"[^\W\'\-_]", text)
diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py
new file mode 100644
index 0000000..7d2efce
--- /dev/null
+++ b/whisper/stt/processing/streaming.py
@@ -0,0 +1,517 @@
+import json
+import sys
+import string
+import numpy as np
+from .vad import remove_non_speech
+from stt import logger, USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SPEECH_DURATION, VAD_MIN_SILENCE_DURATION
+from websockets.legacy.server import WebSocketServerProtocol
+from simple_websocket.ws import Server as WSServer
+
+
+def bytes_to_array(bytes):
+ return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768
+
+
+def processor_output_to_text(o):
+ if o[0] is None:
+ return ""
+ return o[2]
+
+
+def whisper_to_json(o):
+ result = dict()
+ result["text"] = processor_output_to_text(o)
+ json_res = json.dumps(result)
+ return json_res
+
+
+async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel):
+ """Async Decode function endpoint"""
+ res = await ws.recv()
+ try:
+ config = json.loads(res)["config"]
+ sample_rate = config["sample_rate"]
+ logger.info(f"Received config: {config}")
+ except Exception as e:
+        logger.error(f"Failed to read stream configuration: {e}")
+        await ws.close(reason="Failed to load configuration")
+        return
+ model, _ = model_and_alignementmodel
+ if USE_CTRANSLATE2:
+ logger.info("Using ctranslate2 for decoding")
+ asr = FasterWhisperASR(model=model, lan="fr")
+ else:
+ logger.info("Using whisper_timestamped for decoding")
+ asr = WhisperTimestampedASR(model=model, lan="fr")
+ online = OnlineASRProcessor(
+ asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \
+ dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION
+ )
+ logger.info("Starting transcription ...")
+ while True:
+ try:
+ message = await ws.recv()
+            if message is None or message == "":  # Timeout
+                logger.info("Connection closed by client")
+                await ws.close()
+                break
+ except Exception as e:
+ logger.info(f"Connection closed by client: {e}")
+ break
+ if "eof" in str(message):
+ o = online.finish()
+ await ws.send(whisper_to_json(o))
+ logger.info(f"End of stream {message}")
+ await ws.close(reason="End of stream")
+ break
+ online.insert_audio_chunk(bytes_to_array(message))
+ o, _ = online.process_iter()
+ logger.info(o)
+ await ws.send(whisper_to_json(o))
+
+
+def ws_streaming(websocket_server: WSServer, model_and_alignementmodel):
+ """Sync Decode function endpoint"""
+ res = websocket_server.receive(timeout=10)
+ try:
+ config = json.loads(res)["config"]
+ sample_rate = config["sample_rate"]
+ logger.info(f"Received config: {config}")
+ except Exception as e:
+        logger.error(f"Failed to read stream configuration: {e}")
+        websocket_server.close()
+        return
+ model, _ = model_and_alignementmodel
+ if USE_CTRANSLATE2:
+ logger.info("Using ctranslate2 for decoding")
+ asr = FasterWhisperASR(model=model, lan="fr")
+ else:
+ logger.info("Using whisper_timestamped for decoding")
+ asr = WhisperTimestampedASR(model=model, lan="fr")
+ online = OnlineASRProcessor(
+ asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \
+ dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION
+ )
+ logger.info("Starting transcription ...")
+ while True:
+ try:
+ message = websocket_server.receive(timeout=10)
+            if message is None or message == "":  # Timeout
+                logger.info("Connection closed by client")
+                websocket_server.close()
+                break
+ except Exception as e:
+ logger.info(f"Connection closed by client: {e}")
+ break
+ if "eof" in str(message):
+ o = online.finish()
+ websocket_server.send(whisper_to_json(o))
+ logger.info(f"End of stream {message}")
+ websocket_server.close()
+ break
+ online.insert_audio_chunk(bytes_to_array(message))
+ o, _ = online.process_iter()
+ websocket_server.send(whisper_to_json(o))
+
+
+class HypothesisBuffer:
+
+ def __init__(self, logfile=sys.stderr):
+ self.commited_in_buffer = []
+ self.buffer = []
+ self.new = []
+
+ self.last_commited_time = 0
+ self.last_commited_word = None
+ self.last_buffered_time = -1
+
+ self.logfile = logfile
+
+ def insert(self, new, offset):
+ # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content
+ # the new tail is added to self.new
+
+ new = [(a + offset, b + offset, t) for a, b, t in new]
+ self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1]
+
+ if len(self.new) >= 1:
+ a, b, t = self.new[0]
+ if abs(a - self.last_commited_time) < 1:
+ if self.commited_in_buffer:
+ # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped.
+ cn = len(self.commited_in_buffer)
+ nn = len(self.new)
+ for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum
+ c = " ".join(
+ [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][
+ ::-1
+ ]
+ )
+ tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1))
+ if c == tail:
+ logger.debug(f"removing last {i} words:")
+ for j in range(i):
+ logger.debug(f"\t{self.new.pop(0)}")
+ break
+
+ def flush(self):
+ # returns commited chunk = the longest common prefix of 2 last inserts.
+ commit = []
+ while self.new:
+ na, nb, nt = self.new[0]
+
+ if len(self.buffer) == 0:
+ break
+
+ if nt.lower().translate(
+ str.maketrans("", "", string.punctuation)
+ ) == self.buffer[0][2].lower().translate(
+ str.maketrans("", "", string.punctuation)
+ ):
+ commit.append((na, nb, nt))
+ self.last_commited_word = nt
+ self.last_commited_time = nb
+ self.buffer.pop(0)
+ self.new.pop(0)
+ else:
+ break
+ self.buffer = self.new
+ new_non_commit = [
+ i for i in self.buffer if i[1] > self.last_buffered_time - 0.1
+ ]
+ self.last_buffered_time = self.buffer[-1][1] if self.buffer else -1
+ self.new = []
+ self.commited_in_buffer.extend(commit)
+ return commit, new_non_commit
+
+ def pop_commited(self, time):
+ while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time:
+ self.commited_in_buffer.pop(0)
+
+ def complete(self):
+ return self.buffer
+
+
+class OnlineASRProcessor:
+
+ def __init__(
+ self,
+ asr,
+ buffer_trimming=15,
+ vad="auditok",
+ logfile=sys.stderr,
+ sample_rate=16000,
+ min_speech_duration=0.1,
+ min_silence_duration=0.1,
+ dilatation=0.5,
+ ):
+        """asr: WhisperASR object
+        buffer_trimming: threshold in seconds; the audio buffer is trimmed when it gets longer than this.
+        vad: Voice Activity Detection method ("auditok", "silero", or False to disable).
+        sample_rate: sampling rate of the incoming audio.
+        logfile: where to store the log.
+        """
+ self.asr = asr
+ self.logfile = logfile
+
+ self.init()
+
+ self.buffer_trimming_sec = buffer_trimming
+ self.vad = vad
+ self.vad_dilatation = dilatation
+ self.vad_min_speech_duration = min_speech_duration
+ self.vad_min_silence_duration = min_silence_duration
+ self.sampling_rate = sample_rate
+
+ def init(self):
+ """run this when starting or restarting processing"""
+ self.audio_buffer = np.array([], dtype=np.float32)
+ self.buffer_time_offset = 0
+
+ self.transcript_buffer = HypothesisBuffer(logfile=self.logfile)
+ self.commited = []
+ self.last_chunked_at = 0
+
+ self.silence_iters = 0
+
+ def insert_audio_chunk(self, audio):
+ self.audio_buffer = np.append(self.audio_buffer, audio)
+
+ def prompt(self):
+ """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer.
+ "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons.
+ """
+ k = max(0, len(self.commited) - 1)
+ while k > 0 and self.commited[k - 1][1] > self.last_chunked_at:
+ k -= 1
+
+ p = self.commited[:k]
+ p = [t for _, _, t in p]
+ prompt = []
+ l = 0
+ while p and l < 200: # 200 characters prompt size
+ x = p.pop(-1)
+ l += len(x) + 1
+ prompt.append(x)
+ non_prompt = self.commited[k:]
+ return self.asr.sep.join(prompt[::-1]), self.asr.sep.join(
+ t for _, _, t in non_prompt
+ )
+
+ def process_iter(self):
+ """Runs on the current audio buffer.
+ Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, "").
+        The non-empty text is the confirmed (committed) partial transcript.
+ """
+ prompt, non_prompt = self.prompt()
+ logger.debug(f"PROMPT:{prompt}")
+ logger.debug(f"CONTEXT:{non_prompt}")
+ logger.debug(
+ f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s"
+ )
+ if self.vad:
+ np_buffer = np.array(self.audio_buffer)
+            audio_speech, segments, conversion_function = remove_non_speech(
+ np_buffer,
+ method=self.vad,
+ use_sample=True,
+ sample_rate=self.sampling_rate,
+ dilatation=self.vad_dilatation,
+ min_speech_duration=self.vad_min_speech_duration,
+ min_silence_duration=self.vad_min_silence_duration,
+ )
+ res = self.asr.transcribe(audio_speech, init_prompt=prompt)
+ else:
+ res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt)
+ # transform to [(beg,end,"word1"), ...]
+        tsw = self.asr.ts_words(res, conversion_function if self.vad else None)
+ self.transcript_buffer.insert(tsw, self.buffer_time_offset)
+ o, buffer = self.transcript_buffer.flush()
+ self.commited.extend(o)
+ if (
+ buffer
+ and (self.buffer_time_offset + len(self.audio_buffer) / self.sampling_rate)
+ - buffer[-1][1]
+ < 0.05
+ ):
+ # remove the last word if it is too close to the end of the buffer
+ buffer.pop(-1)
+ logger.debug(f"New committed text:{self.to_flush(o)}")
+ logger.debug(
+ f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}"
+ )
+
+ if len(self.audio_buffer) / self.sampling_rate > self.buffer_trimming_sec:
+ self.chunk_completed_segment(
+ res,
+ chunk_silence=self.vad,
+ speech_segments=segments if self.vad else False,
+ )
+
+ logger.debug(
+ f"Len of buffer now: {len(self.audio_buffer)/self.sampling_rate:2.2f}s"
+ )
+ return self.to_flush(o), self.to_flush(buffer)
+
+ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None):
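+        # Trim the audio buffer at the end of the last transcription segment that
+        # is fully committed; with VAD enabled, also trim when the buffer ends
+        # with a long silence.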
+ if self.commited == [] and not chunk_silence:
+ return
+ ends = self.asr.segments_end_ts(res)
+        # end of the last committed word, or 0 if nothing has been committed yet
+        t = self.commited[-1][1] if self.commited else 0
+ if len(ends) > 1:
+ e = ends[-2] + self.buffer_time_offset
+ while len(ends) > 2 and e > t:
+ ends.pop(-1)
+ e = ends[-2] + self.buffer_time_offset
+ if e <= t:
+ logger.debug(f"--- segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ else:
+ logger.debug(f"--- last segment not within commited area")
+ elif chunk_silence:
+            length = len(self.audio_buffer) / self.sampling_rate
+            e = self.buffer_time_offset + length - 2
+            if speech_segments:
+                # speech_segments are in samples (use_sample=True), so convert to seconds
+                end_silence = length - speech_segments[-1][1] / self.sampling_rate
+ if end_silence > 2:
+ logger.debug(f"--- Silence segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ elif speech_segments is not None:
+ logger.debug(f"--- Silence segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ else:
+ logger.debug(f"--- not enough segments to chunk")
+
+ def chunk_at(self, time):
+ """trims the hypothesis and audio buffer at "time" """
+ self.transcript_buffer.pop_commited(time)
+ cut_seconds = time - self.buffer_time_offset
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.sampling_rate) :]
+ self.buffer_time_offset = time
+ self.last_chunked_at = time
+
+ def words_to_sentences(self, words):
+ """Uses self.tokenizer for sentence segmentation of words.
+ Returns: [(beg,end,"sentence 1"),...]
+ """
+
+ cwords = [w for w in words]
+ t = " ".join(o[2] for o in cwords)
+ s = self.tokenizer.split(t)
+ out = []
+ while s:
+ beg = None
+ end = None
+ sent = s.pop(0).strip()
+ fsent = sent
+ while cwords:
+ b, e, w = cwords.pop(0)
+ w = w.strip()
+ if beg is None and sent.startswith(w):
+ beg = b
+ elif end is None and sent == w:
+ end = e
+ out.append((beg, end, fsent))
+ break
+ sent = sent[len(w) :].strip()
+ return out
+
+ def finish(self):
+ """Flush the incomplete text when the whole processing ends.
+ Returns: the same format as self.process_iter()
+ """
+ o = self.transcript_buffer.complete()
+ f = self.to_flush(o)
+ logger.debug(f"last, noncommited:{f}")
+ return f
+
+ def to_flush(
+ self,
+ sents,
+ sep=None,
+ offset=0,
+ ):
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
+ # return: (beg1,end-of-last-sentence,"concatenation of sentences") or (None, None, "") if empty
+ if sep is None:
+ sep = self.asr.sep
+ t = sep.join(s[2] for s in sents)
+ if len(sents) == 0:
+ b = None
+ e = None
+ else:
+ b = offset + sents[0][0]
+ e = offset + sents[-1][1]
+ return (b, e, t)
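+
+# A minimal usage sketch (assuming `asr` wraps an already-loaded model through one
+# of the backends below, and `audio_chunks` is a hypothetical iterable of float32
+# numpy arrays sampled at 16 kHz):
+#
+#     online = OnlineASRProcessor(asr, buffer_trimming=15, vad="auditok")
+#     for chunk in audio_chunks:
+#         online.insert_audio_chunk(chunk)
+#         committed, pending = online.process_iter()   # each is (beg, end, "text")
+#     committed = online.finish()                      # flush what is left in the buffer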
+
+
+class ASRBase:
+
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
+ # "" for faster-whisper because it emits the spaces when needed)
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ self.logfile = logfile
+
+ self.transcribe_kargs = {}
+ self.original_language = lan
+ self.model = model
+
+ def transcribe(self, audio, init_prompt=""):
+        raise NotImplementedError("must be implemented in the child class")
+
+ def use_vad(self, vad_name=None):
+        raise NotImplementedError("must be implemented in the child class")
+
+
+class FasterWhisperASR(ASRBase):
+ """Uses faster-whisper library as the backend. Works much faster, appx 4-times (in offline mode). For GPU, it requires installation with a specific CUDNN version."""
+
+ sep = ""
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ super().__init__(lan, model=model, logfile=logfile)
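+        # Default to greedy, deterministic decoding (beam_size=1, best_of=1,
+        # temperature=0), without conditioning on previous text unless requested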
+ self.transcribe_kargs["beam_size"] = 1
+ self.transcribe_kargs["best_of"] = 1
+ self.transcribe_kargs["temperature"] = 0
+ self.transcribe_kargs["condition_on_previous_text"] = (
+ False if condition_on_previous_text is None else condition_on_previous_text
+ )
+
+ def transcribe(self, audio, init_prompt=""):
+ segments, info = self.model.transcribe(
+ audio,
+ language=self.original_language,
+ initial_prompt=init_prompt,
+ word_timestamps=True,
+ **self.transcribe_kargs,
+ )
+ return list(segments)
+
+ def ts_words(self, segments, timestamps_convert_function=None):
+ o = []
+ for segment in segments:
+ for word in segment.words:
+ # not stripping the spaces -- should not be merged with them!
+ w = word.word
+ if timestamps_convert_function is not None:
+ start, end = timestamps_convert_function(word.start, word.end)
+ t = (start, end, w)
+ else:
+ t = (word.start, word.end, w)
+ o.append(t)
+ return o
+
+ def segments_end_ts(self, res):
+ return [s.end for s in res]
+
+
+class WhisperTimestampedASR(ASRBase):
+ """Uses whisper_timestamped library as the backend. Initially, we tested the code on this backend. It worked, but slower than faster-whisper.
+ On the other hand, the installation for GPU could be easier.
+ """
+
+ sep = " "
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ super().__init__(lan, model=model, logfile=logfile)
+ self.transcribe_kargs["verbose"] = None
+ self.transcribe_kargs["beam_size"] = None
+ self.transcribe_kargs["best_of"] = None
+ self.transcribe_kargs["temperature"] = 0
+ self.transcribe_kargs["condition_on_previous_text"] = (
+ False if condition_on_previous_text is None else condition_on_previous_text
+ )
+ from whisper_timestamped import transcribe_timestamped
+
+ self.transcribe_timestamped = transcribe_timestamped
+
+ def transcribe(self, audio, init_prompt=""):
+ result = self.transcribe_timestamped(
+ self.model,
+ audio,
+ language=self.original_language,
+ initial_prompt=init_prompt,
+ **self.transcribe_kargs,
+ )
+ return result
+
+ def ts_words(self, r, timestamps_convert_function=None):
+ # return: transcribe result object to [(beg,end,"word1"), ...]
+ o = []
+ for s in r["segments"]:
+ for w in s["words"]:
+ if timestamps_convert_function is not None:
+ # print(f"start: {word.start}->{timestamps_convert_function(word.start)}, end: {word.end}->{timestamps_convert_function(word.end)}")
+ start, end = timestamps_convert_function(w["start"], w["end"])
+ t = (start, end, w["text"])
+ else:
+ t = (w["start"], w["end"], w["text"])
+ o.append(t)
+ return o
+
+ def segments_end_ts(self, res):
+ return [s["end"] for s in res["segments"]]
diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py
new file mode 100644
index 0000000..7137410
--- /dev/null
+++ b/whisper/stt/processing/vad.py
@@ -0,0 +1,377 @@
+import numpy as np
+import os
+import shutil
+from stt import logger, USE_CTRANSLATE2
+
+
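+# Cache of loaded silero VAD models (keyed by version) and a flag remembering
+# whether onnxruntime is available (None = not checked yet)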
+_silero_vad_model = {}
+_has_onnx = None
+
+
+def remove_non_speech(
+ audio,
+ use_sample=False,
+ min_speech_duration=0.1,
+ min_silence_duration=1,
+ dilatation=0.5,
+ sample_rate=16000,
+ method="auditok",
+ avoid_empty_speech=False,
+ return_format="tuple",
+):
+ """
+ Remove non-speech segments from audio (using Silero VAD),
+ glue the speech segments together and return the result along with
+ a function to convert timestamps from the new audio to the original audio
+
+ parameters:
+ audio: torch.Tensor
+ audio data *in 16kHz*
+ use_sample: bool
+ if True, return start and end in samples instead of seconds
+ min_speech_duration: float
+ minimum duration (in sec) of a speech segment
+ min_silence_duration: float
+ minimum duration (in sec) of a silence segment
+ dilatation: float
+ how much (in sec) to enlarge each speech segment detected by the VAD
+ method: str
+ method to use to remove non-speech segments
+ avoid_empty_speech: bool
+ if True, avoid returning an empty speech segment (re)
+ """
+
+ if USE_CTRANSLATE2 and method == "silero":
+ from faster_whisper.vad import VadOptions
+
+ options = VadOptions(
+ min_speech_duration_ms=min_speech_duration * 1000,
+ min_silence_duration_ms=min_silence_duration * 1000,
+ )
+ from faster_whisper.vad import get_speech_timestamps
+
+ segments = get_speech_timestamps(audio, vad_options=options)
+ else:
+ segments = get_vad_segments(
+ audio,
+ sample_rate=sample_rate,
+ min_speech_duration=min_speech_duration,
+ min_silence_duration=min_silence_duration,
+ method=method,
+ )
+ segments = apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=True)
+ segments = [(seg["start"], seg["end"]) for seg in segments]
+ if len(segments) == 0:
+ if avoid_empty_speech:
+ segments = [(0, audio.shape[-1])]
+ else:
+            # No speech detected: return empty audio, no segments, and an identity mapping
+            return np.array([]), [], lambda t, t2=None: t if t2 is None else [t, t2]
+ if not use_sample:
+ segments = [
+ (float(s) / sample_rate, float(e) / sample_rate) for s, e in segments
+ ]
+
+ if return_format == "dict":
+ segments = [{"start": s, "end": e} for s, e in segments]
+ return None, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2)
+
+ audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1)
+
+ return audio_speech, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2)
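+
+# Minimal usage sketch (assuming `audio` is a 1-D float32 numpy array at 16 kHz):
+#
+#     speech, segments, convert = remove_non_speech(audio, method="auditok")
+#     # `speech` keeps only the speech samples; `convert(t)` maps a timestamp
+#     # measured in `speech` back to the corresponding time in the original audio.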
+
+
+def do_convert_timestamps(segments, t, t2=None):
+ """
+ Convert timestamp from audio without non-speech segments to original audio (with non-speech segments)
+
+ parameters:
+ segments: list of tuple (start, end) corresponding to non-speech segments in original audio
+ t: timestamp to convert
+ t2: second timestamp to convert (optional), when the two timestamps should be in the same segment
+ """
+ assert len(segments)
+ ioffset = 0 # Input offset
+ ooffset = 0 # Output offset
+ ipreviousend = 0
+ result = []
+ for istart, iend in segments:
+ ostart = ooffset
+ oend = ostart + (iend - istart)
+ ooffset = oend
+ ioffset += istart - ipreviousend
+ ipreviousend = iend
+ t_in = t <= oend
+ t2_in = t_in if t2 is None else t2 <= oend
+ if t_in or t2_in:
+ result.append(
+ [
+ max(istart, min(iend, ioffset + t)),
+ max(istart, min(iend, ioffset + t2)) if t2 is not None else None,
+ ]
+ )
+ if t_in and t2_in:
+ break
+ if not len(result):
+ result.append([ioffset + t, ioffset + t2 if t2 is not None else None])
+
+ if len(result) > 1:
+ # Minimize difference between durations
+ result = sorted(result, key=lambda x: abs(abs(t2 - t) - abs(x[1] - x[0])))
+ result = result[0]
+ if t2 is None:
+ result = round(result[0], 2)
+ else:
+ result = [round(x, 2) for x in result]
+ return result
+
+
+def get_vad_segments(
+ audio,
+ sample_rate=16000,
+ min_speech_duration=0.1,
+ min_silence_duration=0.1,
+ method="auditok",
+):
+ """
+ Get speech segments from audio using the method VAD
+ parameters:
+ audio: torch.Tensor
+ audio data *in 16kHz*
+ output_sample: bool
+ if True, return start and end in samples instead of seconds
+ min_speech_duration: float
+ minimum duration (in sec) of a speech segment
+ min_silence_duration: float
+ minimum duration (in sec) of a silence segment
+ dilatation: float
+ how much (in sec) to enlarge each speech segment detected by the VAD
+ method: str or list
+ VAD method to use (auditok, silero, silero:v3.1)
+ """
+ global _silero_vad_model, _silero_get_speech_ts, _has_onnx
+ if isinstance(method, list):
+ # Explicit timestamps
+ segments = [
+ {"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method
+ ]
+ elif isinstance(method, str) and method.startswith("silero"):
+ version = None
+ _, version = check_vad_method(method, True)
+ # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287
+ need_folder_hack = version and (version < "v4")
+
+ if _silero_vad_model.get(version) is None:
+ # ONNX support since 3.1 in silero
+ if (version is None or version >= "v3.1") and (_has_onnx is not False):
+ onnx = True
+ try:
+ import onnxruntime
+
+                    # Remove warnings like: "Removing initializer 'XXX'. It is not
+                    # used by any node and should be removed from the model."
+                    onnxruntime.set_default_logger_severity(3)
+ _has_onnx = True
+ except ImportError as err:
+                    logger.warning(
+                        "Please install onnxruntime to use silero VAD more efficiently"
+                    )
+ _has_onnx = False
+ onnx = False
+ else:
+ onnx = False
+
+ # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74
+ torch_home = os.environ.get("TORCH_HOME", "~/.cache/torch")
+ repo_or_dir_master = os.path.expanduser(
+ torch_home + "/hub/snakers4_silero-vad_master"
+ )
+ repo_or_dir_specific = (
+ os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}")
+ if version
+ else repo_or_dir_master
+ )
+ repo_or_dir = repo_or_dir_specific
+ tmp_folder = None
+
+ def apply_folder_hack():
+ nonlocal tmp_folder
+ if os.path.exists(repo_or_dir_master):
+ tmp_folder = repo_or_dir_master + ".tmp"
+ shutil.move(repo_or_dir_master, tmp_folder)
+ # Make a symlink to the v3.1 model, otherwise it fails
+ input_exists = os.path.exists(repo_or_dir_specific)
+ if not input_exists:
+ # Make dummy file for the symlink to work
+ os.makedirs(repo_or_dir_specific, exist_ok=True)
+ os.symlink(repo_or_dir_specific, repo_or_dir_master)
+ if not input_exists:
+ shutil.rmtree(repo_or_dir_specific)
+
+ source = "local"
+ if not os.path.exists(repo_or_dir):
+ # Load specific version of silero
+ repo_or_dir = (
+ f"snakers4/silero-vad:{version}"
+ if version
+ else "snakers4/silero-vad"
+ )
+ source = "github"
+ if need_folder_hack:
+ apply_folder_hack()
+ try:
+ from torch.hub import load as torch_load
+ silero_vad_model, utils = torch_load(
+ repo_or_dir=repo_or_dir,
+ model="silero_vad",
+ onnx=onnx,
+ source=source,
+ )
+ _silero_vad_model[version] = silero_vad_model
+ except ImportError as err:
+ raise RuntimeError(
+ f"Please install what is needed to use the silero VAD (or use another VAD method)"
+ ) from err
+ except Exception as err:
+ raise RuntimeError(
+ f"Problem when installing silero with version {version}. Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models"
+ ) from err
+ finally:
+ if need_folder_hack:
+ if os.path.exists(repo_or_dir_master):
+ os.remove(repo_or_dir_master)
+ if tmp_folder:
+ shutil.move(tmp_folder, repo_or_dir_master)
+ assert os.path.isdir(
+ repo_or_dir_specific
+ ), f"Unexpected situation: missing {repo_or_dir_specific}"
+
+            # `utils` only exists when the model has just been loaded above, so keep
+            # the speech-timestamps helper in the module-level global at that point
+            _silero_get_speech_ts = utils[0]
+
+ # Cheap normalization of the volume
+
+ if isinstance(audio, np.ndarray):
+ audio = audio / max(0.1, np.max(np.abs(audio)))
+ else:
+ audio = audio / max(0.1, audio.abs().max())
+ segments = _silero_get_speech_ts(
+ audio,
+ _silero_vad_model[version],
+ sampling_rate=sample_rate,
+ min_speech_duration_ms=round(min_speech_duration * 1000),
+ min_silence_duration_ms=round(min_silence_duration * 1000),
+ return_seconds=False,
+ )
+
+ elif method == "auditok":
+ # Cheap normalization of the volume
+ if isinstance(audio, np.ndarray):
+ audio = audio / max(0.1, np.max(np.abs(audio)))
+ data = (audio * 32767).astype(np.int16).tobytes()
+ else:
+ audio = audio / max(0.1, audio.abs().max())
+ data = (audio.numpy() * 32767).astype(np.int16).tobytes()
+
+ audio_duration = len(audio) / sample_rate
+ from auditok import split
+ segments = split(
+ data,
+ sampling_rate=sample_rate, # sampling frequency in Hz
+ channels=1, # number of channels
+ sample_width=2, # number of bytes per sample
+ min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds
+ max_dur=audio_duration, # maximum duration of an event
+ max_silence=min(
+ audio_duration * 0.95, min_silence_duration
+ ), # maximum duration of tolerated continuous silence within an event
+ energy_threshold=50,
+ drop_trailing_silence=True,
+ )
+
+ segments = [
+ {"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate}
+ for s in segments
+ ]
+
+ else:
+ raise ValueError(f"Got unexpected VAD method {method}")
+ return segments
+
+
+def apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=False):
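+    # Enlarge each speech segment by `dilatation` seconds on both sides, merge
+    # segments that overlap after enlargement, then return them in samples
+    # (rounded) or in seconds depending on `output_sample`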
+ if dilatation > 0:
+ dilatation = round(dilatation * sample_rate)
+ new_segments = []
+ for seg in segments:
+ new_seg = {
+ "start": max(0, seg["start"] - dilatation),
+ "end": min(len(audio), seg["end"] + dilatation),
+ }
+ if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]:
+ new_segments[-1]["end"] = new_seg["end"]
+ else:
+ new_segments.append(new_seg)
+ segments = new_segments
+
+ ratio = 1 if output_sample else 1 / sample_rate
+
+ if ratio != 1:
+ for seg in segments:
+ seg["start"] *= ratio
+ seg["end"] *= ratio
+ if output_sample:
+ for seg in segments:
+ seg["start"] = round(seg["start"])
+ seg["end"] = round(seg["end"])
+ return segments
+
+def check_vad_method(method, with_version=False):
+ """
+ Check whether the VAD method is valid and return the method in a consistent format
+
+ method: str or list or True or False
+ """
+ if method in [True, "True", "true"]:
+ return check_vad_method("silero") # default method
+ elif method in [None, False, "False", "false", "None", "none"]:
+ return None
+ elif not isinstance(method, str) and hasattr(method, "__iter__"):
+ # list of explicit timestamps
+ checked_pairs = []
+ for s_e in method:
+ assert (
+ len(s_e) == 2
+ ), f"Got unexpected element {s_e} in the list of VAD segments. Expect (start, end) pairs"
+ checked_pairs.append(tuple(s_e))
+ return checked_pairs
+ elif isinstance(method, str) and method.startswith("silero"):
+ version = None
+ if method != "silero":
+ assert method.startswith("silero:"), f"Got unexpected VAD method {method}"
+ version = method.split(":")[1]
+ if not version.startswith("v"):
+ version = "v" + version
+ try:
+ assert float(version[1:]) >= 1
+ except:
+ raise ValueError(
+ f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)"
+ )
+ if with_version:
+ return ("silero", version)
+ else:
+ return method
+ elif method == "auditok":
+ try:
+ import auditok
+ except ImportError:
+ raise ImportError(
+ "Please install auditok to use the auditok VAD (or use another VAD method)"
+ )
+ else:
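+        # Assume the method is a string-encoded list of (start, end) pairs
+        # (e.g. "[(0.0, 3.5), (5.0, 7.2)]") and parse it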
+ try:
+ method = eval(method)
+ assert hasattr(method, "__iter__")
+        except Exception:
+ raise ValueError(f"Got unexpected VAD method {method}")
+ return check_vad_method(method, with_version=with_version)
+ return method