diff --git a/.gitignore b/.gitignore index 06b349b..725405d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ -start_container.sh .env* -test/* tmp* +test.log __pycache__ \ No newline at end of file diff --git a/README.md b/README.md index 10f860f..492dd8c 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,7 @@ LinTO-STT can either be used as a standalone transcription service or deployed w The following families of STT models are currently supported (please refer to respective documentation for more details): * [Kaldi models](kaldi/README.md) * [Whisper models](whisper/README.md) +* [Test scripts](test/README.md) LinTO-STT can either be used as a standalone transcription service or deployed within a micro-services infrastructure using a message broker connector. diff --git a/http_server/ingress.py b/http_server/ingress.py index ec21d33..424e71c 100644 --- a/http_server/ingress.py +++ b/http_server/ingress.py @@ -9,7 +9,7 @@ from flask import Flask, json, request from serving import GeventServing, GunicornServing from stt import logger as stt_logger -from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer +from stt.processing import MODEL, USE_GPU, decode, load_wave_buffer, warmup from swagger import setupSwaggerUI app = Flask("__stt-standalone-worker__") @@ -24,7 +24,7 @@ logger.setLevel(logging.INFO) # If websocket streaming route is enabled -if os.environ.get("ENABLE_STREAMING", False) in [True, "true", 1]: +if os.environ.get("ENABLE_STREAMING", "false").lower() in ["true", "1"]: from flask_sock import Sock from stt.processing.streaming import ws_streaming @@ -84,7 +84,9 @@ def transcribe(): logger.error(traceback.format_exc()) logger.error(repr(error)) - return "Server Error: {}".format(str(error)), 400 if isinstance(error, ValueError) else 500 + return "Server Error: {}".format(str(error)), ( + 400 if isinstance(error, ValueError) else 500 + ) @app.errorhandler(405) @@ -128,12 +130,18 @@ def server_error(error): serving_type = GunicornServing logger.debug("Serving with gunicorn") + def post_worker_init(worker): + logger.info(f"Worker {worker.pid} init") + warmup() + logger.info(f"Worker {worker.pid} fully initialized") + serving = serving_type( app, { "bind": f"0.0.0.0:{args.service_port}", "workers": args.workers, "timeout": 3600 * 24, + "post_worker_init": post_worker_init, }, ) logger.info(args) diff --git a/kaldi/.envdefault b/kaldi/.envdefault index 33a394c..f22fa40 100644 --- a/kaldi/.envdefault +++ b/kaldi/.envdefault @@ -7,8 +7,8 @@ ENABLE_STREAMING=true # TASK PARAMETERS SERVICE_NAME=stt -SERVICES_BROKER=redis://192.168.0.1:6379 -BROKER_PASS=password +SERVICES_BROKER=redis://172.17.0.1:6379 +BROKER_PASS= # WEBSOCKET PARAMETERS STREAMING_PORT=80 diff --git a/kaldi/README.md b/kaldi/README.md index 7ebfa85..e7c2036 100644 --- a/kaldi/README.md +++ b/kaldi/README.md @@ -68,7 +68,7 @@ cp kaldi/.envdefault kaldi/.env STT can be used three ways: * Through an [HTTP API](#http-server) using the **http**'s mode. -* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode. +* Through a [message broker](#celery-task) using the **task**'s mode. * Through a [websocket server](#websocket-server) **websocket**'s mode. Mode is specified using the .env value or environment variable ```SERVING_MODE```. 
@@ -99,7 +99,7 @@ This will run a container providing an [HTTP API](#http-api) binded on the host | LM_PATH | Path to the language model on the host machine mounted to /opt/LM | /my/path/to/models/fr-FR_big-v2.2.0 | | MODEL_PATH | Path to the model (using MODEL_TYPE=vosk) mounted to /opt/model | /my/path/to/models/vosk-model | -### Micro-service within LinTO-Platform stack +### Celery task The TASK serving mode connect a celery worker to a message broker. The SERVICE_MODE value in the .env should be set to ```task```. @@ -205,7 +205,10 @@ On a successfull transcription the returned object is a json object structured a * The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) -## Test +## Tests + +See [Test scripts](../test/README.md) for more details about testing. + ### Curl You can test you http API using curl: ```bash diff --git a/kaldi/RELEASE.md b/kaldi/RELEASE.md index 6ce152a..f5bc967 100644 --- a/kaldi/RELEASE.md +++ b/kaldi/RELEASE.md @@ -1,3 +1,6 @@ +# 1.0.2 +- Fix task mode for kaldi by updating SERVICES_BROKER and BROKER_PASS in .envdefault + # 1.0.1 - Fix streaming mode (websocket) in linto-stt-kaldi diff --git a/kaldi/docker-entrypoint.sh b/kaldi/docker-entrypoint.sh index 212b145..74d3b15 100755 --- a/kaldi/docker-entrypoint.sh +++ b/kaldi/docker-entrypoint.sh @@ -25,7 +25,7 @@ fi # Launch parameters, environement variables and dependencies check if [ -z "$SERVICE_MODE" ] then - echo "ERROR: Must specify a serving mode: [ http | task | websocket ]" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)" exit -1 else if [ "$SERVICE_MODE" = "http" ] @@ -48,7 +48,7 @@ else echo "Running Websocket server on port ${STREAMING_PORT:=80}" python websocket/websocketserver.py else - echo "ERROR: Wrong serving command: $1" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)" exit -1 fi fi diff --git a/kaldi/stt/processing/__init__.py b/kaldi/stt/processing/__init__.py index 9f99406..e0476a1 100644 --- a/kaldi/stt/processing/__init__.py +++ b/kaldi/stt/processing/__init__.py @@ -29,5 +29,8 @@ sys.exit(-1) logger.info("Acoustic model and decoding graph loaded. (t={}s)".format(time() - start)) +def warmup(): + pass + # Not implemented yet in Kaldi USE_GPU = False diff --git a/test/README.md b/test/README.md new file mode 100644 index 0000000..1739688 --- /dev/null +++ b/test/README.md @@ -0,0 +1,70 @@ +# LinTO-STT-Tests + +## Use tests + +### HTTP - transcribe + +You can test your http server by using: + +```bash +test_deployment.sh +``` + +> ⚠️ Be sure to check that you use the right port (default port for testing: 8080). + +### HTTP - streaming + +You can test your http streaming route by using: +```bash +test_streaming.py +``` +Be sure to have a working microphone. +> ⚠️ Be sure to check that you use the right port (default port for testing: 8080). + +If you want to test the streaming on a file: +```bash +test_streaming.py --audio_file bonjour.wav +``` + +### Task + +You can test your deployment of the task service mode by using: + +```bash +test_celery.py AUDIO.wav +``` + +with AUDIO.wav the file you want to test on, for example, you can use bonjour.wav. 
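+
+Under the hood, this simply sends a Celery task named `transcribe_task` to the `stt` queue on the Redis broker and waits for the result. A minimal equivalent sketch, assuming the default broker on `localhost:6379` (the full script is `test_celery.py` in this folder):
+
+```python
+from celery import Celery
+
+# Broker/backend URLs match the defaults used by the test scripts of this repository
+celery = Celery(broker="redis://localhost:6379/0", backend="redis://localhost:6379/1")
+
+# Arguments: the audio file path and a boolean flag (as in test_celery.py)
+result = celery.send_task("transcribe_task", ("bonjour.wav", True), queue="stt")
+print(result.get())  # blocks until the worker has produced the transcription
+```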
+ +> ⚠️ Be sure to check that you use the same port in your .env and in test_celery.py (default port for testing: 6379) + + +## Unit tests + +You will need to install: +```bash +pip3 install ddt +``` + +To test the Kaldi models, you will need to download the models (see [Kaldi models](../kaldi/README.md)) and then fill the AM_PATH and LM_PATH fields in the [test_config.ini file](test_config.ini). +> ⚠️ If you don't specify the models, the tests about Kaldi will fail. + +To launch the test you can do : +```bash +python test/test.py +``` + +> ⚠️ Be sure to launch it from the root folder of the repository. + +If you want the test to stop at the first fail use the -f flag: +```bash +python test/test.py -f +``` +If you want to run a subset of test you can use -k with a part of a test name. for example only kaldi tests: +```bash +python test/test.py -k kaldi +``` +or test with VAD=auditok, DEVICE=cuda: +```bash +python test/test.py -k VAD_auditok_DEVICE_cuda +``` \ No newline at end of file diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000..f683583 --- /dev/null +++ b/test/test.py @@ -0,0 +1,318 @@ +import unittest +import os +import time +import subprocess +import requests +import re +from ddt import ddt, idata +from pathlib import Path +import warnings + +TESTDIR = os.path.dirname(os.path.realpath(__file__)) +ROOTDIR = os.path.dirname(TESTDIR) +os.chdir(ROOTDIR) +TESTDIR = os.path.basename(TESTDIR) + + + +def generate_whisper_test_setups(): + dockerfiles = [ + "whisper/Dockerfile.ctranslate2", + "whisper/Dockerfile.ctranslate2.cpu", + "whisper/Dockerfile.torch", + "whisper/Dockerfile.torch.cpu", + ] + + servings = ["http", "task"] + + vads = [None, "false", "auditok", "silero"] + devices = [None, "cpu", "cuda"] + models = ["tiny"] + + for dockerfile in dockerfiles: + for device in devices: + for vad in vads: + for model in models: + for serving in servings: + + # Test CPU dockerfile only on CPU + if dockerfile.endswith("cpu") and device != "cpu": + continue + + # Do not test all VAD settings if not on CPU + if vad not in [None, "silero"]: + if device != "cpu": + continue + + env_variables = "" + if vad: + env_variables += f"VAD={vad} " + if device: + env_variables += f"DEVICE={device} " + env_variables += f"MODEL={model}" + + yield dockerfile, serving, env_variables + +def generate_kaldi_test_setups(): + dockerfiles = ["kaldi/Dockerfile"] + + servings = ["http", "task"] + + for dockerfile in dockerfiles: + for serving in servings: + env_variables = "" + yield dockerfile, serving, env_variables + +def copy_env_file(env_file, env_variables=""): + env_variables = env_variables.split() + env_variables.append("SERVICE_MODE=") + with open(env_file, "r") as f: + lines = f.readlines() + with open(f"{TESTDIR}/.env", "w") as f: + for line in lines: + if not any([line.startswith(b.split("=")[0] + "=") for b in env_variables]): + f.write(line) + +@ddt +class TestRunner(unittest.TestCase): + + built_images = [] + redis_launched = False + + # def __init__(self, *args, **kwargs): + # super(TestRunner, self).__init__(*args, **kwargs) + # self.cleanup() + + def echo_success(self, message): + print('\033[0;32m' + u'\u2714' + '\033[0m ' + message) + + def echo_failure(self, message): + print('\033[0;31m' + u'\u2716' + '\033[0m ' + message) + + def echo_note(self, message): + print(u'\u231B' + ' ' + message) + + def echo_command(self, message): + print(f"$ {message}") + + def report_failure(self, message, expect_failure=False): + if not expect_failure: + self.echo_failure(message) + 
self.cleanup() + if not expect_failure: + self.fail(message) + return message + + def report_success(self): + self.echo_success("Test passed.") + self.cleanup() + + def cleanup(self): + # Check if the container is running + p = subprocess.Popen(["docker", "ps", "-a"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + out, err = p.communicate() + if b"test_container" in out: + self.echo_command("docker stop test_container") + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + time.sleep(0.2) # Without this, the following tests can fail (The container name "/test_container" is already in use) + + def process_output(self, p): + l = p.communicate()[0].decode('utf-8').replace('\n', '\n\t') + e = p.communicate()[1].decode('utf-8').replace('\n', '\n\t') + return f" \u2192 Log Message:\n\t{l}\n \u2192 Error Message:\n\t{e}" + + + def check_http_server_availability(self, server, pid): + total_wait_time = SERVER_STARTING_TIMEOUT # 10 minutes in seconds + retry_interval = 1 # Interval between attempts (in seconds) + elapsed_time = 0 + + while elapsed_time < total_wait_time: + try: + response = requests.head(server) + if response.status_code == 200: + self.echo_note(f"Server: {server} is available after {elapsed_time} sec.") + return + except requests.ConnectionError: + pass + if pid.poll() is not None: + return f"The server container has stopped for an unexpected reason.\n{self.process_output(pid)}" + + time.sleep(retry_interval) + elapsed_time += retry_interval + + return f"Server: {server} is not available after {total_wait_time} seconds, server launching must have failed.\n{self.process_output(pid)}" + + def launch_redis(self): + if TestRunner.redis_launched: + return + cmd = "docker run --rm -p 6379:6379 --name test_redis redis/redis-stack-server:latest redis-server /etc/redis-stack.conf --protected-mode no --bind 0.0.0.0 --loglevel debug" + self.echo_command(cmd) + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + time.sleep(2) + if p.poll() is not None: + self.cleanup() + return f"Redis server failed to start.\n{self.process_output(p)}", None + TestRunner.redis_launched = True + + def build_and_run_container(self, serving, docker_image, env_variables, use_local_cache): + self.echo_note(f"* Docker image: {docker_image}") + self.echo_note(f"* Options.....: {env_variables}") + build_args = "" + for i, env in enumerate(env_variables.split()): + if i>0 and env_variables.split()[i-1] =="-v": + build_args += f"-v {env} " + elif env=="-v": + continue + else: + build_args += f"--env {env} " + build_args += f"--env SERVICE_MODE={serving} " + if use_local_cache: + home = str(Path.home()) + build_args += f"-v {home}/.cache:/root/.cache " + + if serving == "task": + self.launch_redis() + build_args += "-v {}/:/opt/audio ".format(os.getcwd()) + + tag = f"test_{os.path.basename(docker_image)}" + if tag not in TestRunner.built_images: + # Only build images that have not been built yet + cmd = f'docker build . 
-f {docker_image} -t linto-stt-test:{tag}' + self.echo_command(cmd) + start_time = time.time() + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + p.wait() + end_time = time.time() + if p.poll() != 0: + self.cleanup() + return f"Docker build failed.\n{self.process_output(p)}", None + self.echo_note(f"Docker image has been successfully built in {end_time - start_time:.0f} sec.") + TestRunner.built_images.append(tag) + + cmd=f"docker run --rm -p 8080:80 --name test_container --env-file {TESTDIR}/.env --gpus all {build_args} linto-stt-test:{tag}" + self.echo_command(cmd) + p = subprocess.Popen(cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if p.poll() is not None: + return f"Docker container failed to start.\n{self.process_output(p)}", None + return None, p + + def transcribe(self, command, regex, test_file, error_message, success_message, timeout=None): + start = time.time() + res = subprocess.run(command, shell=True, timeout=timeout, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + end = time.time() + if res.returncode != 0: + raise FileNotFoundError(f"Error: {res.stderr.decode('utf-8')}") + res = res.stdout.decode('utf-8') + if not re.search(regex, res): + message = f"{error_message}: The string '{res}' is not matching the regex ({regex}), the server didn't transcribe correctly." + return self.report_failure(message) + self.echo_note(f"{success_message} has transcribed {test_file} in {end - start:.0f} sec.") + return + + def run_test(self, docker_image="whisper/Dockerfile.ctranslate2", serving="http", env_variables="", test_file=f"{TESTDIR}/bonjour.wav", use_local_cache=True, expect_failure=False): + warnings.simplefilter("ignore", ResourceWarning) + regex = "" + if os.path.basename(test_file) == "bonjour.wav": + regex = re.compile("[bB]onjour") + r, pid = self.build_and_run_container(serving, docker_image, env_variables, use_local_cache) + if r: + return self.report_failure(r, expect_failure=expect_failure) + if serving == "http": + r=self.check_http_server_availability("http://localhost:8080/healthcheck", pid) + if r: + return self.report_failure(r, expect_failure=expect_failure) + cmd = f'curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@{test_file};type=audio/wav"' + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error transcription", "HTTP route 'transcribe'") + if r: + return self.report_failure(r, expect_failure=expect_failure) + cmd = f"python3 {TESTDIR}/test_streaming.py --audio_file {test_file}" + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error streaming", "HTTP route 'streaming'") + elif serving == "task": + # you can be stuck here if the server crashed bc the task will be in the queue forever + cmd = f"python3 {TESTDIR}/test_celery.py {test_file}" + self.echo_command(cmd) + r = self.transcribe(cmd, regex, test_file, "Error task", "TASK route", timeout=60) + else: + raise RuntimeError(f"Unknown serving mode: {serving}") + if r: + return self.report_failure(r, expect_failure=expect_failure) + if not expect_failure: + self.report_success() + return "" + + def setUp(self): + # Print an empty line because unittest prints the name of the test first, without a newline + print() + print("-"*70) + + def tearDown(self): + print("-"*70) + + @idata(generate_kaldi_test_setups()) + def test_01_kaldi_integration(self, setup): + dockerfile, serving, env_variables = setup + if AM_PATH is None or LM_PATH is None or 
AM_PATH=="" or LM_PATH=="": + self.fail("AM or LM path not provided. Skipping kaldi test.") + if not os.path.exists(AM_PATH) or not os.path.exists(LM_PATH): + self.fail(f"AM or LM path not found: {AM_PATH} or {LM_PATH}") + copy_env_file("kaldi/.envdefault") + env_variables += f"-v {AM_PATH}:/opt/AM -v {LM_PATH}:/opt/LM" + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + + @idata(generate_whisper_test_setups()) + def test_03_whisper_integration(self, setup): + dockerfile, serving, env_variables = setup + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(dockerfile, serving=serving, env_variables=env_variables) + + def test_02_whisper_failures_cuda_on_cpu_dockerfile(self): + env_variables = "MODEL=tiny DEVICE=cuda" + dockerfile = "whisper/Dockerfile.ctranslate2.cpu" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn("cannot open shared object file", self.run_test(dockerfile, env_variables=env_variables, expect_failure=True)) + + def test_02_whisper_failures_not_existing_file(self): + env_variables = "MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) + with self.assertRaises(FileNotFoundError): + self.run_test(test_file="notexisting", env_variables=env_variables, expect_failure=True) + self.cleanup() + + def test_02_whisper_failures_wrong_vad(self): + env_variables = "VAD=whatever MODEL=tiny" + copy_env_file("whisper/.envdefault", env_variables) + self.assertIn("Got unexpected VAD method whatever", self.run_test(env_variables=env_variables, expect_failure=True)) + + def test_04_model_whisper(self): + env_variables = "MODEL=small" + copy_env_file("whisper/.envdefault", env_variables) + self.run_test(env_variables=env_variables) + +def finalize_tests(): + subprocess.run(["docker", "stop", "test_container"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + subprocess.run(["docker", "stop", "test_redis"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + +AM_PATH = None +LM_PATH = None +SERVER_STARTING_TIMEOUT = 60 + +if __name__ == '__main__': + from configparser import ConfigParser + config = ConfigParser() + + config.read(f"{TESTDIR}/test_config.ini") + + SERVER_STARTING_TIMEOUT = int(config.get('server', 'STARTING_TIMEOUT')) if config.get('server', 'STARTING_TIMEOUT')!="" else SERVER_STARTING_TIMEOUT + + AM_PATH = config.get('kaldi', 'AM_PATH') + LM_PATH = config.get('kaldi', 'LM_PATH') + + try: + unittest.main(verbosity=2) + finally: + finalize_tests() diff --git a/test/test_celery.py b/test/test_celery.py new file mode 100755 index 0000000..59ed62e --- /dev/null +++ b/test/test_celery.py @@ -0,0 +1,16 @@ +import sys +from celery import Celery + +def transcribe_task(file_path): + celery = Celery(broker='redis://localhost:6379/0', backend='redis://localhost:6379/1') + r = celery.send_task( + 'transcribe_task', + ( + file_path, + True, + ), + queue='stt') + return r.get() + +if __name__ == '__main__': + print(transcribe_task(sys.argv[1])) \ No newline at end of file diff --git a/test/test_config.ini b/test/test_config.ini new file mode 100644 index 0000000..76bf72d --- /dev/null +++ b/test/test_config.ini @@ -0,0 +1,6 @@ +[server] +STARTING_TIMEOUT=60 + +[kaldi] +AM_PATH= +LM_PATH= \ No newline at end of file diff --git a/test/test_deployment.sh b/test/test_deployment.sh index b1b8d36..84daac8 100755 --- a/test/test_deployment.sh +++ b/test/test_deployment.sh @@ -1 +1 @@ -curl -X POST "http://localhost:8888/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F 
"file=@bonjour.wav;type=audio/wav" +curl -X POST "http://localhost:8080/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@bonjour.wav;type=audio/wav" diff --git a/test/test_streaming.py b/test/test_streaming.py new file mode 100644 index 0000000..d78f6cf --- /dev/null +++ b/test/test_streaming.py @@ -0,0 +1,120 @@ +import asyncio +import websockets +import json +import shutil + +def linstt_streaming(*kargs, **kwargs): + text = asyncio.run(_linstt_streaming(*kargs, **kwargs)) + return text + +async def _linstt_streaming( + audio_file, + ws_api = "ws://localhost:8080/streaming", + verbose = False, +): + + if audio_file is None: + import pyaudio + # Init pyaudio + audio = pyaudio.PyAudio() + stream = audio.open(format=pyaudio.paInt16, channels=1, rate=16000, input=True, frames_per_buffer=2048) + if verbose > 1: + print("Start recording") + else: + stream = open(audio_file, "rb") + + alive = True + text = "" + partial = None + + try: + async with websockets.connect(ws_api) as websocket: + await websocket.send(json.dumps({"config" : {"sample_rate": 16000 }})) + while alive: + try: + data = stream.read(32000) + if audio_file and not data: + if verbose > 1: + print("\nAudio file finished") + alive = False + await websocket.send(data) + res = await websocket.recv() + message = json.loads(res) + if message is None: + if verbose > 1: + print("\n Received None") + continue + if "partial" in message.keys(): + partial = message["partial"] + if verbose: + print_partial(partial) + elif "text" in message.keys(): + line = message["text"] + if verbose: + print_final(line) + if line: + if text: + text += "\n" + text += line + elif verbose: + print("???", message) + except KeyboardInterrupt: + if verbose > 1: + print("\nKeyboard interrupt") + alive = False + await websocket.send(json.dumps({"eof" : 1})) + res = await websocket.recv() + message = json.loads(res) + if isinstance(message, str): + message = json.loads(message) + if text: + text += " " + text += message["text"] + try: + res = await websocket.recv() + except websockets.ConnectionClosedOK: + if verbose > 1: + print("Websocket Closed") + except KeyboardInterrupt: + if verbose > 1: + print("\nKeyboard interrupt") + if verbose: + print_final("= FULL TRANSCRIPTION ", background="=") + print(text) + + return text + +def print_partial(text): + text = text + "…" + terminal_size = shutil.get_terminal_size() + width = terminal_size.columns + start = ((len(text) - 1)// width) * width + if start > 0: + print(" "*width, end="\r") + if start < len(text) - 1: + print("…"+text[start+1:]+" "*(width-len(text)-start-1), end="\r") + else: + print(text[-width:], end="\r") + else: + print(text, end="\r") + +def print_final(text, background=" "): + terminal_size = shutil.get_terminal_size() + width = terminal_size.columns + print(background * width, end="\r") + print(text) + +if __name__ == "__main__": + + import argparse + parser = argparse.ArgumentParser(description='Transcribe input streaming (from mic or a file) with LinSTT', + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument('--server', help='Transcription server', + default="ws://localhost:8080/streaming", + ) + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose mode") + parser.add_argument("--audio_file", default=None, help="A path to an audio file to transcribe (if not provided, use mic)") + args = parser.parse_args() + + res = linstt_streaming(args.audio_file, args.server, verbose=2 if args.verbose else 1) \ No 
newline at end of file diff --git a/whisper/.envdefault b/whisper/.envdefault index 75919c0..a8f8794 100644 --- a/whisper/.envdefault +++ b/whisper/.envdefault @@ -1,7 +1,7 @@ ############################################ # SERVING PARAMETERS ############################################ -# "http" or "task" +# "http" or "task" or "websocket" SERVICE_MODE=http # Below: used when SERVICE_MODE=task @@ -9,6 +9,12 @@ SERVICE_NAME=stt SERVICES_BROKER=redis://172.17.0.1:6379 BROKER_PASS= +# HTTP PARAMETERS +ENABLE_STREAMING=true + +# WEBSOCKET PARAMETERS +STREAMING_PORT=80 + ############################################ # STT MODELING PARAMETERS ############################################ @@ -30,6 +36,15 @@ PROMPT= # This option is experimental (and not implemented with ctranslate2). # ALIGNMENT_MODEL=wav2vec +# Voice Activity Detection (VAD) method +# It can be either "0"/"false" (no VAD), "silero", or "1"/"true"/"auditok" (by default) +# VAD=auditok + +# Voice Activity Detection (VAD) parameters +# VAD_DILATATION=0.1 +# VAD_MIN_SPEECH_DURATION=0.1 +# VAD_MIN_SILENCE_DURATION=0.1 + ############################################ # EFFICIENCY PARAMETERS ############################################ @@ -40,7 +55,7 @@ PROMPT= # CUDA_VISIBLE_DEVICES=0 # Number of threads per worker when running on CPU -OMP_NUM_THREADS=4 +NUM_THREADS=4 -# Number of workers +# Number of workers minus one (all except from the main one) CONCURRENCY=2 diff --git a/whisper/Dockerfile.ctranslate2 b/whisper/Dockerfile.ctranslate2 index c2b3cd5..5fd3c53 100644 --- a/whisper/Dockerfile.ctranslate2 +++ b/whisper/Dockerfile.ctranslate2 @@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.ctranslate2.cpu b/whisper/Dockerfile.ctranslate2.cpu index df5eac7..1f0f40c 100644 --- a/whisper/Dockerfile.ctranslate2.cpu +++ b/whisper/Dockerfile.ctranslate2.cpu @@ -6,7 +6,6 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins # Install python dependencies COPY whisper/requirements.ctranslate2.txt ./ RUN pip install --no-cache-dir -r requirements.ctranslate2.txt && rm requirements.ctranslate2.txt - WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app @@ -15,6 +14,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.torch b/whisper/Dockerfile.torch index 06b22f3..acd0b08 100644 --- a/whisper/Dockerfile.torch +++ b/whisper/Dockerfile.torch @@ -15,6 +15,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/Dockerfile.torch.cpu b/whisper/Dockerfile.torch.cpu index 17a3fb8..2d45336 100644 --- a/whisper/Dockerfile.torch.cpu +++ b/whisper/Dockerfile.torch.cpu @@ -12,7 +12,6 @@ RUN pip3 install \ # Install python dependencies COPY whisper/requirements.torch.txt ./ RUN pip install --no-cache-dir -r requirements.torch.txt && rm requirements.torch.txt - 
WORKDIR /usr/src/app COPY celery_app /usr/src/app/celery_app @@ -21,6 +20,7 @@ COPY websocket /usr/src/app/websocket COPY document /usr/src/app/document COPY whisper/stt /usr/src/app/stt COPY whisper/docker-entrypoint.sh wait-for-it.sh healthcheck.sh ./ +COPY test/bonjour.wav /usr/src/app/test/bonjour.wav ENV PYTHONPATH="${PYTHONPATH}:/usr/src/app/stt" diff --git a/whisper/README.md b/whisper/README.md index 41dc46a..52d2122 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -114,17 +114,22 @@ cp whisper/.envdefault whisper/.env | PARAMETER | DESCRIPTION | EXEMPLE | |---|---|---| -| SERVICE_MODE | STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | -| MODEL | Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... | -| LANGUAGE | (Optional) Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... | -| PROMPT | (Optional) Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | -| ALIGNMENT_MODEL | (Optional and deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... | -| DEVICE | (Optional) Device to use for the model | `cpu` \| `cuda` ... | -| CUDA_VISIBLE_DEVICES | (Optional) GPU device index to use, if several. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | -| CONCURRENCY | Maximum number of parallel requests | `2` | -| SERVICE_NAME | (For the task mode) queue's name for task processing | `my-stt` | -| SERVICE_BROKER | (For the task mode) URL of the message broker | `redis://my-broker:6379` | +| SERVICE_MODE | (Required) STT serving mode see [Serving mode](#serving-mode) | `http` \| `task` | +| MODEL | (Required) Path to a Whisper model, type of Whisper model used, or HuggingFace identifier of a Whisper model. | `large-v3` \| `distil-whisper/distil-large-v2` \| \ \| ... | +| LANGUAGE | Language to recognize | `*` \| `fr` \| `fr-FR` \| `French` \| `en` \| `en-US` \| `English` \| ... | +| PROMPT | Prompt to use for the Whisper model | `some free text to encourage a certain transcription style (disfluencies, no punctuation, ...)` | +| DEVICE | Device to use for the model (by default, GPU/CUDA is used if it is available, CPU otherwise) | `cpu` \| `cuda` | +| NUM_THREADS | Number of threads (maximum) to use for things running on CPU | `1` \| `4` \| ... | +| CUDA_VISIBLE_DEVICES | GPU device index to use, when running on GPU/CUDA. We also recommend to set `CUDA_DEVICE_ORDER=PCI_BUS_ID` on multi-GPU machines | `0` \| `1` \| `2` \| ... | +| CONCURRENCY | Maximum number of parallel requests (number of workers minus one) | `2` | +| VAD | Voice Activity Detection method. Use "false" to disable. If not specified, the default is auditok VAD. | `true` \| `false` \| `1` \| `0` \| `auditok` \| `silero` +| ENABLE_STREAMING | (For the http mode) enable the /streaming websocket route | `true\|false` | +| STREAMING_PORT | (For the websocket mode) the listening port for ingoing WS connexions. 
| `80` |
+| SERVICE_NAME | (For the task mode only) queue's name for task processing | `my-stt` |
+| SERVICE_BROKER | (For the task mode only) URL of the message broker | `redis://my-broker:6379` |
| BROKER_PASS | (For the task mode only) broker password | `my-password` \| (empty) |
+| ALIGNMENT_MODEL | (Deprecated) Path to the wav2vec model for word alignment, or name of HuggingFace repository or torchaudio pipeline | `WAV2VEC2_ASR_BASE_960H` \| `jonatasgrosman/wav2vec2-large-xlsr-53-english` \| \ \| ... |
+
#### MODEL environment variable
@@ -184,7 +189,7 @@ and also `yue(cantonese)` since large-v3.
STT can be used in two ways:
* Through an [HTTP API](#http-server) using the **http**'s mode.
-* Through a [message broker](#micro-service-within-linto-platform-stack) using the **task**'s mode.
+* Through a [message broker](#celery-task) using the **task**'s mode.
Mode is specified using the .env value or environment variable ```SERVING_MODE```.
```bash
@@ -217,11 +222,11 @@ You may also want to add specific options:
| Variables | Description | Example |
|:-|:-|:-|
| `HOST_SERVING_PORT` | Host serving port | 8080 |
-| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
| `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
-| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
-### Micro-service within LinTO-Platform stack
+### Celery task
The TASK serving mode connect a celery worker to a message broker.
The SERVICE_MODE value in the .env should be set to ```task```.
@@ -248,10 +253,16 @@ You may also want to add specific options:
| Variables | Description | Example |
|:-|:-|:-|
| `` | Shared audio folder mounted to /opt/audio | /my/path/to/models/vosk-model |
-| `` | (Optional) Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
+| `` | Path to a folder to download wav2vec alignment models when relevant | /home/username/.cache |
| `` | Path to the Whisper model on the host machine mounted to /opt/model.pt | /my/path/to/models/medium.pt |
-| `` | (Optional) Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+| `` | Path to a folder to a custom wav2vec alignment model | /my/path/to/models/wav2vec |
+
+### Websocket Server
+The websocket serving mode deploys a streaming transcription service only.
+The SERVICE_MODE value in the .env should be set to ```websocket```.
+
+Usage is the same as the [http streaming API](#streaming).
## Usages
### HTTP API
@@ -286,6 +297,18 @@ Return the transcripted text using "text/plain" or a json object when using "app
}
```
+#### /streaming
+The /streaming route is accessible if the ENABLE_STREAMING environment variable is set to true.
+
+The route accepts websocket connections. Exchanges are structured as follows:
+1. The client sends a json {"config": {"sample_rate":16000}}.
+2. The client sends an audio chunk (go to 3) or {"eof" : 1} (go to 5).
+3. The server sends either a partial result {"partial" : "this is a "} or a final result {"text": "this is a transcription"}.
+4. Back to 2.
+5. The server sends a final result and closes the connection.
+
+> The connection will be closed and the worker will be freed if no chunk is received for 10s.
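+
+A minimal Python client sketch following this exchange (a non-authoritative illustration; it assumes a server listening on `ws://localhost:8080/streaming` and a 16 kHz mono audio file — see `test/test_streaming.py` in this repository for the full test client):
+
+```python
+import asyncio, json
+import websockets  # pip install websockets
+
+async def transcribe(path, uri="ws://localhost:8080/streaming"):
+    async with websockets.connect(uri) as ws:
+        # 1. Send the streaming configuration
+        await ws.send(json.dumps({"config": {"sample_rate": 16000}}))
+        with open(path, "rb") as f:
+            # 2./3./4. Send audio chunks; the server answers each one with a partial or final result
+            while chunk := f.read(32000):
+                await ws.send(chunk)
+                msg = json.loads(await ws.recv())
+                print(msg.get("partial") or msg.get("text", ""))
+        # 2. -> 5. Signal end of stream and read the last result
+        await ws.send(json.dumps({"eof": 1}))
+        final = json.loads(await ws.recv())
+        if isinstance(final, str):  # the final message may itself be a JSON-encoded string
+            final = json.loads(final)
+        return final.get("text", "")
+
+# Example: print(asyncio.run(transcribe("test/bonjour.wav")))
+```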
+ #### /docs The /docs route offers a OpenAPI/swagger interface. @@ -320,13 +343,24 @@ On a successfull transcription the returned object is a json object structured a * The confidence field contains the overall confidence for the transcription. (0.0 if with_metadata=False) -## Test +## Tests + +See [Test scripts](../test/README.md) for more details about testing. + ### Curl -You can test you http API using curl: +You can test your http API using curl: + ```bash curl -X POST "http://YOUR_SERVICE:YOUR_PORT/transcribe" -H "accept: application/json" -H "Content-Type: multipart/form-data" -F "file=@YOUR_FILE;type=audio/x-wav" ``` +### Streaming +You can test your streaming API using a websocket: + +```bash +python test/test_streaming.py --server ws://YOUR_SERVICE:YOUR_PORT/streaming --audio_file test/bonjour.wav +``` + ## License This project is developped under the AGPLv3 License (see LICENSE). @@ -339,3 +373,4 @@ This project is developped under the AGPLv3 License (see LICENSE). * [HuggingFace Transformers](https://github.com/huggingface/transformers) * [SpeechBrain](https://github.com/speechbrain/speechbrain) * [TorchAudio](https://github.com/pytorch/audio) +* [Whisper_Streaming](https://github.com/ufal/whisper_streaming) \ No newline at end of file diff --git a/whisper/RELEASE.md b/whisper/RELEASE.md index f54537e..84de80f 100644 --- a/whisper/RELEASE.md +++ b/whisper/RELEASE.md @@ -1,3 +1,10 @@ +# 1.0.3 +- Make Voice Activity Detection (VAD) configurable +- Change default VAD from silero (neural approach) to auditok (heuristical approach), because silero can have unpredictable behaviour on different corner cases +- Streaming support +- New NUM_THREADS env variable to control the number of threads +- Load the model when launching the service (not at the first request) + # 1.0.2 - ct2/faster_whisper: Upgrade faster_whisper and support recent distilled models - ct2/faster_whisper: Fix possible gluing of different words together diff --git a/whisper/docker-entrypoint.sh b/whisper/docker-entrypoint.sh index 97a3804..09ea120 100755 --- a/whisper/docker-entrypoint.sh +++ b/whisper/docker-entrypoint.sh @@ -14,7 +14,7 @@ fi # Launch parameters, environement variables and dependencies check if [ -z "$SERVICE_MODE" ] then - echo "ERROR: Must specify a serving mode: [ http | task | websocket ]" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (None was specified)" exit -1 else if [ "$SERVICE_MODE" = "http" ] @@ -41,9 +41,12 @@ else /usr/src/app/wait-for-it.sh $(echo $SERVICES_BROKER | cut -d'/' -f 3) --timeout=20 --strict -- echo " $SERVICES_BROKER (Service Broker) is up" || exit 1 echo "RUNNING STT CELERY WORKER" celery --app=celery_app.celeryapp worker $OPT -Ofair --queues=${SERVICE_NAME} -c ${CONCURRENCY} -n ${SERVICE_NAME}_worker@%h - + elif [ "$SERVICE_MODE" == "websocket" ] + then + echo "Running Websocket server on port ${STREAMING_PORT:=80}" + python3 websocket/websocketserver.py else - echo "ERROR: Wrong serving command: $SERVICE_MODE" + echo "ERROR: Must specify an environment variable SERVICE_MODE in [ http | task | websocket ] (got SERVICE_MODE=$SERVICE_MODE)" exit -1 fi fi diff --git a/whisper/requirements.ctranslate2.txt b/whisper/requirements.ctranslate2.txt index e471fd7..87b5e80 100644 --- a/whisper/requirements.ctranslate2.txt +++ b/whisper/requirements.ctranslate2.txt @@ -11,6 +11,7 @@ regex requests>=2.26.0 wavio>=0.0.4 websockets +auditok #faster_whisper==1.0.1 # This is version faster_whisper==1.0.1 + option for (persistent) prompt 
+ fix for large-v3 git+https://github.com/linto-ai/faster-whisper.git \ No newline at end of file diff --git a/whisper/requirements.torch.txt b/whisper/requirements.torch.txt index 3976414..e5f5f93 100644 --- a/whisper/requirements.torch.txt +++ b/whisper/requirements.torch.txt @@ -15,4 +15,5 @@ wavio>=0.0.4 websockets whisper-timestamped onnxruntime -torchaudio \ No newline at end of file +torchaudio +auditok \ No newline at end of file diff --git a/whisper/stt/__init__.py b/whisper/stt/__init__.py index f5551af..8bac458 100644 --- a/whisper/stt/__init__.py +++ b/whisper/stt/__init__.py @@ -12,6 +12,20 @@ # see https://github.com/guillaumekln/faster-whisper/issues/150 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # GPU in the right order +if os.environ.get("VAD","auditok").lower() in ["true", "1"]: + VAD = "auditok" +elif os.environ.get("VAD","auditok").lower() in ["false", "0"]: + VAD = False +else: + VAD = os.environ.get("VAD","auditok") + +VAD_DILATATION = float(os.environ.get("VAD_DILATATION", 0.5)) +VAD_MIN_SPEECH_DURATION = float(os.environ.get("VAD_MIN_SPEECH_DURATION", 0.1)) +VAD_MIN_SILENCE_DURATION = float(os.environ.get("VAD_MAX_SILENCE_DURATION", 0.1)) + +NUM_THREADS = os.environ.get("NUM_THREADS", os.environ.get("OMP_NUM_THREADS")) +NUM_THREADS = int(NUM_THREADS) + try: import faster_whisper @@ -36,3 +50,21 @@ USE_TORCHAUDIO = True except ImportError: USE_TORCHAUDIO = False + +if USE_CTRANSLATE2: + def set_num_threads(n): + # os.environ["OMP_NUM_THREADS"] = str(n) + pass +else: + import torch + DEFAULT_NUM_THREADS = torch.get_num_threads() + def set_num_threads(n): + torch.set_num_threads(n) + +# Number of CPU threads +if NUM_THREADS is None: + NUM_THREADS = DEFAULT_NUM_THREADS +if NUM_THREADS is not None: + NUM_THREADS = int(NUM_THREADS) +# For Torch, we will set it afterward, because setting that before loading the model can hang the process (see https://github.com/pytorch/pytorch/issues/58962) +set_num_threads(1) diff --git a/whisper/stt/processing/__init__.py b/whisper/stt/processing/__init__.py index b0e7f6d..b82a482 100644 --- a/whisper/stt/processing/__init__.py +++ b/whisper/stt/processing/__init__.py @@ -2,7 +2,7 @@ import os from lockfile import FileLock -from stt import USE_CTRANSLATE2, logger +from stt import USE_CTRANSLATE2, VAD, logger, set_num_threads, NUM_THREADS from .alignment_model import get_alignment_model, load_alignment_model from .decoding import decode @@ -18,12 +18,19 @@ "USE_GPU", ] - +def warmup(): + model.check_loaded() + audio_data = load_audiofile("test/bonjour.wav") + transcription = decode(audio_data, MODEL, False) + logger.info(f"Warmup result: {transcription}") + class LazyLoadedModel: - def __init__(self, model_type, device): + def __init__(self, model_type, device, num_threads): self.model_type = model_type self.device = device + self.num_threads = num_threads self._model = None + self.has_set_num_threads = False def check_loaded(self): if self._model is None: @@ -31,12 +38,19 @@ def check_loaded(self): with FileLock(lockfile): self._model = load_whisper_model(self.model_type, device=self.device) + def check_num_threads(self): + if not self.has_set_num_threads and self.num_threads: + set_num_threads(self.num_threads) + self.has_set_num_threads = True + def __getattr__(self, name): self.check_loaded() + self.check_num_threads() return getattr(self._model, name) def __call__(self, *args, **kwargs): self.check_loaded() + self.check_num_threads() return self._model(*args, **kwargs) @@ -51,16 +65,8 @@ def __call__(self, *args, **kwargs): 
language = get_language() logger.info(f"Using language {language}") -# Load ASR model -model_type = os.environ.get("MODEL", "medium") -logger.info( - f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." -) -try: - model = LazyLoadedModel(model_type, device=device) - # model = load_whisper_model(model_type, device=device) -except Exception as err: - raise Exception("Failed to load transcription model: {}".format(str(err))) from err +logger.info(f"VAD={VAD}") +logger.info(f"USE_CTRANSLATE2={USE_CTRANSLATE2}") # Load alignment model (if any) alignment_model = get_alignment_model(os.environ.get("alignment_model"), language) @@ -77,4 +83,19 @@ def __call__(self, *args, **kwargs): ) alignment_model = {} # Alignement model(s) will be loaded on the fly -MODEL = (model, alignment_model) + +# Load ASR model +model_type = os.environ.get("MODEL", "medium") +logger.info( + f"Loading Whisper model {model_type} ({'local' if os.path.exists(model_type) else 'remote'})..." +) +try: + model = LazyLoadedModel(model_type, device=device, num_threads=NUM_THREADS) + MODEL = (model, alignment_model) + if USE_GPU: + warmup() +except Exception as err: + raise Exception("Failed to load transcription model: {}".format(str(err))) from err + + + \ No newline at end of file diff --git a/whisper/stt/processing/decoding.py b/whisper/stt/processing/decoding.py index 5f032e4..9ead692 100644 --- a/whisper/stt/processing/decoding.py +++ b/whisper/stt/processing/decoding.py @@ -5,8 +5,9 @@ from typing import Tuple, Union import numpy as np -from stt import USE_CTRANSLATE2, logger +from stt import USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SILENCE_DURATION, VAD_MIN_SPEECH_DURATION, logger +from .vad import remove_non_speech from .alignment_model import get_alignment_model, load_alignment_model from .text_normalize import normalize_text, remove_emoji, remove_punctuation from .utils import SAMPLE_RATE, get_language @@ -17,7 +18,6 @@ import whisper_timestamped USE_ACCURATE = True -USE_VAD = True if USE_ACCURATE: default_beam_size = 5 @@ -47,7 +47,6 @@ def decode( ) -> dict: if language is None: language = get_language() - kwargs = copy.copy(locals()) kwargs.pop("model_and_alignementmodel") kwargs["model"], kwargs["alignment_model"] = model_and_alignementmodel @@ -64,7 +63,6 @@ def decode( kwargs.pop("alignment_model") res = decode_ct2(**kwargs) else: - print("OK") res = decode_torch(**kwargs) logger.info("Transcription complete (t={}s)".format(time.time() - start_t)) @@ -73,24 +71,30 @@ def decode( def decode_ct2( - audio, model, with_word_timestamps, language, remove_punctuation_from_words, **kwargs + audio, + model, + with_word_timestamps, + language, + remove_punctuation_from_words, + **kwargs, ): kwargs["no_speech_threshold"] = 1 # To avoid empty output if kwargs.get("beam_size") is None: kwargs["beam_size"] = 1 if kwargs.get("best_of") is None: kwargs["best_of"] = 1 - + if VAD: + _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \ + min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, return_format="dict") segments, info = model.transcribe( audio, word_timestamps=with_word_timestamps, language=language, # Careful with the following options max_initial_timestamp=10000.0, - vad_filter=USE_VAD, + vad_filter=speech_segments if VAD else False, **kwargs, ) - segments = list(segments) return format_faster_whisper_response( @@ -118,6 +122,10 @@ def decode_torch( fp16 = model.device != 
torch.device("cpu") + if VAD: + _, speech_segments, _ = remove_non_speech(audio, use_sample=True, method=VAD, dilatation=VAD_DILATATION, \ + min_silence_duration=VAD_MIN_SILENCE_DURATION, min_speech_duration=VAD_MIN_SPEECH_DURATION,) + kwargs = dict( language=language, fp16=fp16, @@ -127,13 +135,15 @@ def decode_torch( condition_on_previous_text=condition_on_previous_text, no_speech_threshold=no_speech_threshold, compression_ratio_threshold=compression_ratio_threshold, - vad=USE_VAD, + vad=speech_segments if VAD else False, initial_prompt=prompt, ) if alignment_model is None: # Use Whisper cross-attention weights - whisper_res = whisper_timestamped.transcribe(model, audio, verbose=None, **kwargs) + whisper_res = whisper_timestamped.transcribe( + model, audio, verbose=None, **kwargs + ) if language is None: language = whisper_res["language"] logger.info(f"Detected language: {language}") @@ -175,7 +185,9 @@ def decode_torch( result["text"] = text result["language"] = language result["confidence-score"] = ( - np.exp(np.array([r["avg_logprob"] for r in segments])).mean() if len(segments) else 0.0 + np.exp(np.array([r["avg_logprob"] for r in segments])).mean() + if len(segments) + else 0.0 ) if not with_word_timestamps: @@ -251,7 +263,9 @@ def decode_torch( return result -def format_whisper_timestamped_response(transcription, remove_punctuation_from_words=False): +def format_whisper_timestamped_response( + transcription, remove_punctuation_from_words=False +): """Format Whisper response.""" for i, seg in enumerate(transcription["segments"][:-1]): @@ -281,9 +295,11 @@ def format_whisper_timestamped_response(transcription, remove_punctuation_from_w return { "text": transcription["text"].strip(), "language": transcription["language"], - "confidence-score": round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2) - if len(segments) - else 0.0, + "confidence-score": ( + round(np.exp(np.array([r["avg_logprob"] for r in segments])).mean(), 2) + if len(segments) + else 0.0 + ), "words": words, } @@ -307,7 +323,10 @@ def checked_timestamps(start, end=None): if end == start: pass # end = start + 0.01 else: - print("WARNING, end timestamp %f is smaller than start timestamp %f" % (end, start)) + print( + "WARNING, end timestamp %f is smaller than start timestamp %f" + % (end, start) + ) if end is None: return start return (start, end) @@ -327,7 +346,11 @@ def checked_timestamps(start, end=None): and len(words) and len(word_strip) > 1 and word_strip[0] in glue_punctuations - and (word_strip == word_string or not contains_alphanum(words[-1]["text"]) or not contains_alphanum(word_strip)) + and ( + word_strip == word_string + or not contains_alphanum(words[-1]["text"]) + or not contains_alphanum(word_strip) + ) ): words[-1]["text"] += word_strip words[-1]["confidence"].append(word.probability) @@ -368,5 +391,6 @@ def checked_timestamps(start, end=None): transcription, remove_punctuation_from_words=remove_punctuation_from_words ) + def contains_alphanum(text: str) -> bool: - return re.search(r"[^\W\'\-_]", text) \ No newline at end of file + return re.search(r"[^\W\'\-_]", text) diff --git a/whisper/stt/processing/streaming.py b/whisper/stt/processing/streaming.py new file mode 100644 index 0000000..7d2efce --- /dev/null +++ b/whisper/stt/processing/streaming.py @@ -0,0 +1,517 @@ +import json +import sys +import string +import numpy as np +from .vad import remove_non_speech +from stt import logger, USE_CTRANSLATE2, VAD, VAD_DILATATION, VAD_MIN_SPEECH_DURATION, VAD_MIN_SILENCE_DURATION +from 
websockets.legacy.server import WebSocketServerProtocol +from simple_websocket.ws import Server as WSServer + + +def bytes_to_array(bytes): + return np.frombuffer(bytes, dtype=np.int16).astype(np.float32) / 32768 + + +def processor_output_to_text(o): + if o[0] is None: + return "" + return o[2] + + +def whisper_to_json(o): + result = dict() + result["text"] = processor_output_to_text(o) + json_res = json.dumps(result) + return json_res + + +async def wssDecode(ws: WebSocketServerProtocol, model_and_alignementmodel): + """Async Decode function endpoint""" + res = await ws.recv() + try: + config = json.loads(res)["config"] + sample_rate = config["sample_rate"] + logger.info(f"Received config: {config}") + except Exception as e: + logger.error("Failed to read stream configuration") + await ws.close(reason="Failed to load configuration") + model, _ = model_and_alignementmodel + if USE_CTRANSLATE2: + logger.info("Using ctranslate2 for decoding") + asr = FasterWhisperASR(model=model, lan="fr") + else: + logger.info("Using whisper_timestamped for decoding") + asr = WhisperTimestampedASR(model=model, lan="fr") + online = OnlineASRProcessor( + asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \ + dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION + ) + logger.info("Starting transcription ...") + while True: + try: + message = await ws.recv() + if message is None or message == "": # Timeout + logger.info("Connection closed by client") + ws.close() + except Exception as e: + logger.info(f"Connection closed by client: {e}") + break + if "eof" in str(message): + o = online.finish() + await ws.send(whisper_to_json(o)) + logger.info(f"End of stream {message}") + await ws.close(reason="End of stream") + break + online.insert_audio_chunk(bytes_to_array(message)) + o, _ = online.process_iter() + logger.info(o) + await ws.send(whisper_to_json(o)) + + +def ws_streaming(websocket_server: WSServer, model_and_alignementmodel): + """Sync Decode function endpoint""" + res = websocket_server.receive(timeout=10) + try: + config = json.loads(res)["config"] + sample_rate = config["sample_rate"] + logger.info(f"Received config: {config}") + except Exception as e: + logger.error("Failed to read stream configuration") + websocket_server.close() + model, _ = model_and_alignementmodel + if USE_CTRANSLATE2: + logger.info("Using ctranslate2 for decoding") + asr = FasterWhisperASR(model=model, lan="fr") + else: + logger.info("Using whisper_timestamped for decoding") + asr = WhisperTimestampedASR(model=model, lan="fr") + online = OnlineASRProcessor( + asr, logfile=sys.stderr, buffer_trimming=8, vad=VAD, sample_rate=sample_rate, \ + dilatation=VAD_DILATATION, min_speech_duration=VAD_MIN_SPEECH_DURATION, min_silence_duration=VAD_MIN_SILENCE_DURATION + ) + logger.info("Starting transcription ...") + while True: + try: + message = websocket_server.receive(timeout=10) + if message is None or message == "": # Timeout + logger.info("Connection closed by client") + websocket_server.close() + except Exception as e: + logger.info(f"Connection closed by client: {e}") + break + if "eof" in str(message): + o = online.finish() + websocket_server.send(whisper_to_json(o)) + logger.info(f"End of stream {message}") + websocket_server.close() + break + online.insert_audio_chunk(bytes_to_array(message)) + o, _ = online.process_iter() + websocket_server.send(whisper_to_json(o)) + + +class HypothesisBuffer: + + def __init__(self, logfile=sys.stderr): + 
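+        # commited_in_buffer: words already committed that still fall within the current audio buffer
+        # buffer: hypothesis words from the previous insert(), not yet confirmed
+        # new: hypothesis words from the latest insert()
+        # flush() commits the longest common prefix of `buffer` and `new`,
+        # i.e. a word is emitted once two consecutive hypotheses agree on it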
self.commited_in_buffer = [] + self.buffer = [] + self.new = [] + + self.last_commited_time = 0 + self.last_commited_word = None + self.last_buffered_time = -1 + + self.logfile = logfile + + def insert(self, new, offset): + # compare self.commited_in_buffer and new. It inserts only the words in new that extend the commited_in_buffer, it means they are roughly behind last_commited_time and new in content + # the new tail is added to self.new + + new = [(a + offset, b + offset, t) for a, b, t in new] + self.new = [(a, b, t) for a, b, t in new if a > self.last_commited_time - 0.1] + + if len(self.new) >= 1: + a, b, t = self.new[0] + if abs(a - self.last_commited_time) < 1: + if self.commited_in_buffer: + # it's going to search for 1, 2, ..., 5 consecutive words (n-grams) that are identical in commited and new. If they are, they're dropped. + cn = len(self.commited_in_buffer) + nn = len(self.new) + for i in range(1, min(min(cn, nn), 5) + 1): # 5 is the maximum + c = " ".join( + [self.commited_in_buffer[-j][2] for j in range(1, i + 1)][ + ::-1 + ] + ) + tail = " ".join(self.new[j - 1][2] for j in range(1, i + 1)) + if c == tail: + logger.debug(f"removing last {i} words:") + for j in range(i): + logger.debug(f"\t{self.new.pop(0)}") + break + + def flush(self): + # returns commited chunk = the longest common prefix of 2 last inserts. + commit = [] + while self.new: + na, nb, nt = self.new[0] + + if len(self.buffer) == 0: + break + + if nt.lower().translate( + str.maketrans("", "", string.punctuation) + ) == self.buffer[0][2].lower().translate( + str.maketrans("", "", string.punctuation) + ): + commit.append((na, nb, nt)) + self.last_commited_word = nt + self.last_commited_time = nb + self.buffer.pop(0) + self.new.pop(0) + else: + break + self.buffer = self.new + new_non_commit = [ + i for i in self.buffer if i[1] > self.last_buffered_time - 0.1 + ] + self.last_buffered_time = self.buffer[-1][1] if self.buffer else -1 + self.new = [] + self.commited_in_buffer.extend(commit) + return commit, new_non_commit + + def pop_commited(self, time): + while self.commited_in_buffer and self.commited_in_buffer[0][1] <= time: + self.commited_in_buffer.pop(0) + + def complete(self): + return self.buffer + + +class OnlineASRProcessor: + + def __init__( + self, + asr, + buffer_trimming=15, + vad="auditok", + logfile=sys.stderr, + sample_rate=16000, + min_speech_duration=0.1, + min_silence_duration=0.1, + dilatation=0.5, + ): + """asr: WhisperASR object + tokenizer: sentence tokenizer object for the target language. Must have a method *split* that behaves like the one of MosesTokenizer. It can be None, if "segment" buffer trimming option is used, then tokenizer is not used at all. + ("segment", 15) + buffer_trimming: a pair of (option, seconds), where option is either "sentence" or "segment", and seconds is a number. Buffer is trimmed if it is longer than "seconds" threshold. Default is the most recommended option. + logfile: where to store the log. 
+ """ + self.asr = asr + self.logfile = logfile + + self.init() + + self.buffer_trimming_sec = buffer_trimming + self.vad = vad + self.vad_dilatation = dilatation + self.vad_min_speech_duration = min_speech_duration + self.vad_min_silence_duration = min_silence_duration + self.sampling_rate = sample_rate + + def init(self): + """run this when starting or restarting processing""" + self.audio_buffer = np.array([], dtype=np.float32) + self.buffer_time_offset = 0 + + self.transcript_buffer = HypothesisBuffer(logfile=self.logfile) + self.commited = [] + self.last_chunked_at = 0 + + self.silence_iters = 0 + + def insert_audio_chunk(self, audio): + self.audio_buffer = np.append(self.audio_buffer, audio) + + def prompt(self): + """Returns a tuple: (prompt, context), where "prompt" is a 200-character suffix of commited text that is inside of the scrolled away part of audio buffer. + "context" is the commited text that is inside the audio buffer. It is transcribed again and skipped. It is returned only for debugging and logging reasons. + """ + k = max(0, len(self.commited) - 1) + while k > 0 and self.commited[k - 1][1] > self.last_chunked_at: + k -= 1 + + p = self.commited[:k] + p = [t for _, _, t in p] + prompt = [] + l = 0 + while p and l < 200: # 200 characters prompt size + x = p.pop(-1) + l += len(x) + 1 + prompt.append(x) + non_prompt = self.commited[k:] + return self.asr.sep.join(prompt[::-1]), self.asr.sep.join( + t for _, _, t in non_prompt + ) + + def process_iter(self): + """Runs on the current audio buffer. + Returns: a tuple (beg_timestamp, end_timestamp, "text"), or (None, None, ""). + The non-emty text is confirmed (committed) partial transcript. + """ + prompt, non_prompt = self.prompt() + logger.debug(f"PROMPT:{prompt}") + logger.debug(f"CONTEXT:{non_prompt}") + logger.debug( + f"Transcribing {len(self.audio_buffer)/self.sampling_rate:2.2f} seconds starting at {self.buffer_time_offset:2.2f}s" + ) + if self.vad: + np_buffer = np.array(self.audio_buffer) + audio_speech, segments, convertion_function = remove_non_speech( + np_buffer, + method=self.vad, + use_sample=True, + sample_rate=self.sampling_rate, + dilatation=self.vad_dilatation, + min_speech_duration=self.vad_min_speech_duration, + min_silence_duration=self.vad_min_silence_duration, + ) + res = self.asr.transcribe(audio_speech, init_prompt=prompt) + else: + res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt) + # transform to [(beg,end,"word1"), ...] 
+ tsw = self.asr.ts_words(res, convertion_function if self.vad else None)
+ self.transcript_buffer.insert(tsw, self.buffer_time_offset)
+ o, buffer = self.transcript_buffer.flush()
+ self.commited.extend(o)
+ if (
+ buffer
+ and (self.buffer_time_offset + len(self.audio_buffer) / self.sampling_rate)
+ - buffer[-1][1]
+ < 0.05
+ ):
+ # remove the last word if it is too close to the end of the buffer
+ buffer.pop(-1)
+ logger.debug(f"New committed text:{self.to_flush(o)}")
+ logger.debug(
+ f"Buffered text:{self.to_flush(self.transcript_buffer.complete())}"
+ )
+
+ if len(self.audio_buffer) / self.sampling_rate > self.buffer_trimming_sec:
+ self.chunk_completed_segment(
+ res,
+ chunk_silence=self.vad,
+ speech_segments=segments if self.vad else False,
+ )
+
+ logger.debug(
+ f"Len of buffer now: {len(self.audio_buffer)/self.sampling_rate:2.2f}s"
+ )
+ return self.to_flush(o), self.to_flush(buffer)
+
+ def chunk_completed_segment(self, res, chunk_silence=False, speech_segments=None):
+ if self.commited == [] and not chunk_silence:
+ return
+ ends = self.asr.segments_end_ts(res)
+ t = self.commited[-1][1]
+ if len(ends) > 1:
+ e = ends[-2] + self.buffer_time_offset
+ while len(ends) > 2 and e > t:
+ ends.pop(-1)
+ e = ends[-2] + self.buffer_time_offset
+ if e <= t:
+ logger.debug(f"--- segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ else:
+ logger.debug("--- last segment not within committed area")
+ elif chunk_silence:
+ length = len(self.audio_buffer) / self.sampling_rate
+ e = self.buffer_time_offset + length - 2
+ if speech_segments:
+ end_silence = length - speech_segments[-1][1]
+ if end_silence > 2:
+ logger.debug(f"--- Silence segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ elif speech_segments is not None:
+ logger.debug(f"--- Silence segment chunked at {e:2.2f}")
+ self.chunk_at(e)
+ else:
+ logger.debug("--- not enough segments to chunk")
+
+ def chunk_at(self, time):
+ """Trims the hypothesis and audio buffer at "time"."""
+ self.transcript_buffer.pop_commited(time)
+ cut_seconds = time - self.buffer_time_offset
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.sampling_rate) :]
+ self.buffer_time_offset = time
+ self.last_chunked_at = time
+
+ def words_to_sentences(self, words):
+ """Uses self.tokenizer for sentence segmentation of words.
+ Note: no tokenizer is set by this class; a sentence tokenizer with a *split* method must be attached as self.tokenizer before calling this method.
+ Returns: [(beg,end,"sentence 1"),...]
+ """
+
+ cwords = [w for w in words]
+ t = " ".join(o[2] for o in cwords)
+ s = self.tokenizer.split(t)
+ out = []
+ while s:
+ beg = None
+ end = None
+ sent = s.pop(0).strip()
+ fsent = sent
+ while cwords:
+ b, e, w = cwords.pop(0)
+ w = w.strip()
+ if beg is None and sent.startswith(w):
+ beg = b
+ elif end is None and sent == w:
+ end = e
+ out.append((beg, end, fsent))
+ break
+ sent = sent[len(w) :].strip()
+ return out
+
+ def finish(self):
+ """Flush the incomplete text when the whole processing ends.
+ Returns: the same format as self.process_iter()
+ """
+ o = self.transcript_buffer.complete()
+ f = self.to_flush(o)
+ logger.debug(f"last, noncommitted:{f}")
+ return f
+
+ def to_flush(
+ self,
+ sents,
+ sep=None,
+ offset=0,
+ ):
+ # concatenates the timestamped words or sentences into one sequence that is flushed in one line
+ # sents: [(beg1, end1, "sentence1"), ...] or [] if empty
+ # return: (beg1, end-of-last-sentence, "concatenation of sentences") or (None, None, "") if empty
+ if sep is None:
+ sep = self.asr.sep
+ t = sep.join(s[2] for s in sents)
+ if len(sents) == 0:
+ b = None
+ e = None
+ else:
+ b = offset + sents[0][0]
+ e = offset + sents[-1][1]
+ return (b, e, t)
+
+
+class ASRBase:
+
+ sep = " " # join transcribed words with this character (" " for whisper_timestamped,
+ # "" for faster-whisper because it emits the spaces when needed)
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ self.logfile = logfile
+
+ self.transcribe_kargs = {}
+ self.original_language = lan
+ self.model = model
+
+ def transcribe(self, audio, init_prompt=""):
+ raise NotImplementedError("must be implemented in the child class")
+
+ def use_vad(self, vad_name=None):
+ raise NotImplementedError("must be implemented in the child class")
+
+
+class FasterWhisperASR(ASRBase):
+ """Uses the faster-whisper library as the backend. Works much faster, approximately 4 times faster (in offline mode). For GPU, it requires an installation with a specific cuDNN version."""
+
+ sep = ""
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ super().__init__(lan, model=model, logfile=logfile)
+ self.transcribe_kargs["beam_size"] = 1
+ self.transcribe_kargs["best_of"] = 1
+ self.transcribe_kargs["temperature"] = 0
+ self.transcribe_kargs["condition_on_previous_text"] = (
+ False if condition_on_previous_text is None else condition_on_previous_text
+ )
+
+ def transcribe(self, audio, init_prompt=""):
+ segments, info = self.model.transcribe(
+ audio,
+ language=self.original_language,
+ initial_prompt=init_prompt,
+ word_timestamps=True,
+ **self.transcribe_kargs,
+ )
+ return list(segments)
+
+ def ts_words(self, segments, timestamps_convert_function=None):
+ o = []
+ for segment in segments:
+ for word in segment.words:
+ # not stripping the spaces -- they should not be merged with the words!
+ w = word.word
+ if timestamps_convert_function is not None:
+ start, end = timestamps_convert_function(word.start, word.end)
+ t = (start, end, w)
+ else:
+ t = (word.start, word.end, w)
+ o.append(t)
+ return o
+
+ def segments_end_ts(self, res):
+ return [s.end for s in res]
+
+
+class WhisperTimestampedASR(ASRBase):
+ """Uses the whisper_timestamped library as the backend. We initially tested the code on this backend; it works, but is slower than faster-whisper.
+ On the other hand, installation for GPU can be easier.
+ """
+
+ sep = " "
+
+ def __init__(
+ self, lan, model=None, logfile=sys.stderr, condition_on_previous_text=None
+ ):
+ super().__init__(lan, model=model, logfile=logfile)
+ self.transcribe_kargs["verbose"] = None
+ self.transcribe_kargs["beam_size"] = None
+ self.transcribe_kargs["best_of"] = None
+ self.transcribe_kargs["temperature"] = 0
+ self.transcribe_kargs["condition_on_previous_text"] = (
+ False if condition_on_previous_text is None else condition_on_previous_text
+ )
+ from whisper_timestamped import transcribe_timestamped
+
+ self.transcribe_timestamped = transcribe_timestamped
+
+ def transcribe(self, audio, init_prompt=""):
+ result = self.transcribe_timestamped(
+ self.model,
+ audio,
+ language=self.original_language,
+ initial_prompt=init_prompt,
+ **self.transcribe_kargs,
+ )
+ return result
+
+ def ts_words(self, r, timestamps_convert_function=None):
+ # return: transcribe result object converted to [(beg, end, "word1"), ...]
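+ # e.g. [(0.0, 0.48, "bonjour"), (0.52, 0.90, "monde")] (illustrative values)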
+ o = []
+ for s in r["segments"]:
+ for w in s["words"]:
+ if timestamps_convert_function is not None:
+ start, end = timestamps_convert_function(w["start"], w["end"])
+ t = (start, end, w["text"])
+ else:
+ t = (w["start"], w["end"], w["text"])
+ o.append(t)
+ return o
+
+ def segments_end_ts(self, res):
+ return [s["end"] for s in res["segments"]]
diff --git a/whisper/stt/processing/vad.py b/whisper/stt/processing/vad.py
new file mode 100644
index 0000000..7137410
--- /dev/null
+++ b/whisper/stt/processing/vad.py
@@ -0,0 +1,377 @@
+import numpy as np
+import os
+import shutil
+from stt import logger, USE_CTRANSLATE2
+
+
+_silero_vad_model = {}
+_has_onnx = None
+
+
+def remove_non_speech(
+ audio,
+ use_sample=False,
+ min_speech_duration=0.1,
+ min_silence_duration=1,
+ dilatation=0.5,
+ sample_rate=16000,
+ method="auditok",
+ avoid_empty_speech=False,
+ return_format="tuple",
+):
+ """
+ Remove non-speech segments from audio (using a VAD method),
+ glue the speech segments together and return the result along with
+ a function to convert timestamps from the new audio to the original audio
+
+ parameters:
+ audio: torch.Tensor or np.ndarray
+ audio data *in 16kHz*
+ use_sample: bool
+ if True, return start and end in samples instead of seconds
+ min_speech_duration: float
+ minimum duration (in sec) of a speech segment
+ min_silence_duration: float
+ minimum duration (in sec) of a silence segment
+ dilatation: float
+ how much (in sec) to enlarge each speech segment detected by the VAD
+ sample_rate: int
+ sampling rate of the audio (in Hz)
+ method: str
+ method to use to remove non-speech segments
+ avoid_empty_speech: bool
+ if True, never return an empty list of speech segments (the whole audio is used as one segment instead)
+ return_format: str
+ "tuple" to return segments as (start, end) tuples, "dict" to return them as {"start", "end"} dicts
+ """
+
+ if USE_CTRANSLATE2 and method == "silero":
+ from faster_whisper.vad import VadOptions
+
+ options = VadOptions(
+ min_speech_duration_ms=min_speech_duration * 1000,
+ min_silence_duration_ms=min_silence_duration * 1000,
+ )
+ from faster_whisper.vad import get_speech_timestamps
+
+ segments = get_speech_timestamps(audio, vad_options=options)
+ else:
+ segments = get_vad_segments(
+ audio,
+ sample_rate=sample_rate,
+ min_speech_duration=min_speech_duration,
+ min_silence_duration=min_silence_duration,
+ method=method,
+ )
+ segments = apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=True)
+ segments = [(seg["start"], seg["end"]) for seg in segments]
+ if len(segments) == 0:
+ if avoid_empty_speech:
+ segments = [(0, audio.shape[-1])]
+ else:
+ return np.array([]), [], lambda t, t2=None: t if t2 is None else [t, t2]
+ if not use_sample:
+ segments = [
+ (float(s) / sample_rate, float(e) / sample_rate) for s, e in segments
+ ]
+
+ if return_format == "dict":
+ segments = [{"start": s, "end": e} for s, e in segments]
+ return None, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2)
+
+ audio_speech = np.concatenate([audio[..., s:e] for s, e in segments], axis=-1)
+
+ return audio_speech, segments, lambda t, t2=None: do_convert_timestamps(segments, t, t2)
+
+
+def do_convert_timestamps(segments, t, t2=None):
+ """
+ Convert a timestamp from the audio without non-speech segments to the original audio (with non-speech segments)
+
+ parameters:
+ segments: list of (start, end) tuples corresponding to the speech segments kept from the original audio
+ t: timestamp to convert
+ t2: second timestamp to convert (optional), when the two timestamps should be in the same segment
+ """
+ assert len(segments)
+ ioffset = 0 # Input offset
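+ # ioffset accumulates the non-speech time skipped so far (original timeline);
+ # ooffset tracks the running end time in the concatenated speech-only timeline.
+ # The loop below shifts t (and t2) from the speech-only timeline back to the original one.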
+ ooffset = 0 # Output offset
+ ipreviousend = 0
+ result = []
+ for istart, iend in segments:
+ ostart = ooffset
+ oend = ostart + (iend - istart)
+ ooffset = oend
+ ioffset += istart - ipreviousend
+ ipreviousend = iend
+ t_in = t <= oend
+ t2_in = t_in if t2 is None else t2 <= oend
+ if t_in or t2_in:
+ result.append(
+ [
+ max(istart, min(iend, ioffset + t)),
+ max(istart, min(iend, ioffset + t2)) if t2 is not None else None,
+ ]
+ )
+ if t_in and t2_in:
+ break
+ if not len(result):
+ result.append([ioffset + t, ioffset + t2 if t2 is not None else None])
+
+ if len(result) > 1:
+ # Minimize the difference between durations
+ result = sorted(result, key=lambda x: abs(abs(t2 - t) - abs(x[1] - x[0])))
+ result = result[0]
+ if t2 is None:
+ result = round(result[0], 2)
+ else:
+ result = [round(x, 2) for x in result]
+ return result
+
+
+def get_vad_segments(
+ audio,
+ sample_rate=16000,
+ min_speech_duration=0.1,
+ min_silence_duration=0.1,
+ method="auditok",
+):
+ """
+ Get speech segments from audio using the given VAD method
+ parameters:
+ audio: torch.Tensor or np.ndarray
+ audio data *in 16kHz*
+ sample_rate: int
+ sampling rate of the audio (in Hz)
+ min_speech_duration: float
+ minimum duration (in sec) of a speech segment
+ min_silence_duration: float
+ minimum duration (in sec) of a silence segment
+ method: str or list
+ VAD method to use (auditok, silero, silero:v3.1), or a list of explicit (start, end) timestamps
+ """
+ global _silero_vad_model, _silero_get_speech_ts, _has_onnx
+ if isinstance(method, list):
+ # Explicit timestamps
+ segments = [
+ {"start": s * sample_rate, "end": e * sample_rate} for (s, e) in method
+ ]
+ elif isinstance(method, str) and method.startswith("silero"):
+ version = None
+ _, version = check_vad_method(method, True)
+ # See discussion https://github.com/linto-ai/whisper-timestamped/pull/142/files#r1398326287
+ need_folder_hack = version and (version < "v4")
+
+ if _silero_vad_model.get(version) is None:
+ # ONNX support since silero 3.1
+ if (version is None or version >= "v3.1") and (_has_onnx is not False):
+ onnx = True
+ try:
+ import onnxruntime
+
+ onnxruntime.set_default_logger_severity(
+ 3
+ ) # Remove warning "Removing initializer 'XXX'. It is not used by any node and should be removed from the model."
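+ # onnxruntime imported successfully: keep onnx=True so the ONNX version of the silero model is loaded below.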
+ _has_onnx = True + except ImportError as err: + logger.warning( + f"Please install onnxruntime to use more efficiently silero VAD" + ) + _has_onnx = False + onnx = False + else: + onnx = False + + # Choose silero version because of problems with version 4, see https://github.com/linto-ai/whisper-timestamped/issues/74 + torch_home = os.environ.get("TORCH_HOME", "~/.cache/torch") + repo_or_dir_master = os.path.expanduser( + torch_home + "/hub/snakers4_silero-vad_master" + ) + repo_or_dir_specific = ( + os.path.expanduser(torch_home + f"/hub/snakers4_silero-vad_{version}") + if version + else repo_or_dir_master + ) + repo_or_dir = repo_or_dir_specific + tmp_folder = None + + def apply_folder_hack(): + nonlocal tmp_folder + if os.path.exists(repo_or_dir_master): + tmp_folder = repo_or_dir_master + ".tmp" + shutil.move(repo_or_dir_master, tmp_folder) + # Make a symlink to the v3.1 model, otherwise it fails + input_exists = os.path.exists(repo_or_dir_specific) + if not input_exists: + # Make dummy file for the symlink to work + os.makedirs(repo_or_dir_specific, exist_ok=True) + os.symlink(repo_or_dir_specific, repo_or_dir_master) + if not input_exists: + shutil.rmtree(repo_or_dir_specific) + + source = "local" + if not os.path.exists(repo_or_dir): + # Load specific version of silero + repo_or_dir = ( + f"snakers4/silero-vad:{version}" + if version + else "snakers4/silero-vad" + ) + source = "github" + if need_folder_hack: + apply_folder_hack() + try: + from torch.hub import load as torch_load + silero_vad_model, utils = torch_load( + repo_or_dir=repo_or_dir, + model="silero_vad", + onnx=onnx, + source=source, + ) + _silero_vad_model[version] = silero_vad_model + except ImportError as err: + raise RuntimeError( + f"Please install what is needed to use the silero VAD (or use another VAD method)" + ) from err + except Exception as err: + raise RuntimeError( + f"Problem when installing silero with version {version}. 
Check versions here: https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models" + ) from err + finally: + if need_folder_hack: + if os.path.exists(repo_or_dir_master): + os.remove(repo_or_dir_master) + if tmp_folder: + shutil.move(tmp_folder, repo_or_dir_master) + assert os.path.isdir( + repo_or_dir_specific + ), f"Unexpected situation: missing {repo_or_dir_specific}" + + _silero_get_speech_ts = utils[0] + + # Cheap normalization of the volume + + if isinstance(audio, np.ndarray): + audio = audio / max(0.1, np.max(np.abs(audio))) + else: + audio = audio / max(0.1, audio.abs().max()) + segments = _silero_get_speech_ts( + audio, + _silero_vad_model[version], + sampling_rate=sample_rate, + min_speech_duration_ms=round(min_speech_duration * 1000), + min_silence_duration_ms=round(min_silence_duration * 1000), + return_seconds=False, + ) + + elif method == "auditok": + # Cheap normalization of the volume + if isinstance(audio, np.ndarray): + audio = audio / max(0.1, np.max(np.abs(audio))) + data = (audio * 32767).astype(np.int16).tobytes() + else: + audio = audio / max(0.1, audio.abs().max()) + data = (audio.numpy() * 32767).astype(np.int16).tobytes() + + audio_duration = len(audio) / sample_rate + from auditok import split + segments = split( + data, + sampling_rate=sample_rate, # sampling frequency in Hz + channels=1, # number of channels + sample_width=2, # number of bytes per sample + min_dur=min_speech_duration, # minimum duration of a valid audio event in seconds + max_dur=audio_duration, # maximum duration of an event + max_silence=min( + audio_duration * 0.95, min_silence_duration + ), # maximum duration of tolerated continuous silence within an event + energy_threshold=50, + drop_trailing_silence=True, + ) + + segments = [ + {"start": s._meta.start * sample_rate, "end": s._meta.end * sample_rate} + for s in segments + ] + + else: + raise ValueError(f"Got unexpected VAD method {method}") + return segments + + +def apply_dilatation(segments, dilatation, sample_rate, audio, output_sample=False): + if dilatation > 0: + dilatation = round(dilatation * sample_rate) + new_segments = [] + for seg in segments: + new_seg = { + "start": max(0, seg["start"] - dilatation), + "end": min(len(audio), seg["end"] + dilatation), + } + if len(new_segments) > 0 and new_segments[-1]["end"] >= new_seg["start"]: + new_segments[-1]["end"] = new_seg["end"] + else: + new_segments.append(new_seg) + segments = new_segments + + ratio = 1 if output_sample else 1 / sample_rate + + if ratio != 1: + for seg in segments: + seg["start"] *= ratio + seg["end"] *= ratio + if output_sample: + for seg in segments: + seg["start"] = round(seg["start"]) + seg["end"] = round(seg["end"]) + return segments + +def check_vad_method(method, with_version=False): + """ + Check whether the VAD method is valid and return the method in a consistent format + + method: str or list or True or False + """ + if method in [True, "True", "true"]: + return check_vad_method("silero") # default method + elif method in [None, False, "False", "false", "None", "none"]: + return None + elif not isinstance(method, str) and hasattr(method, "__iter__"): + # list of explicit timestamps + checked_pairs = [] + for s_e in method: + assert ( + len(s_e) == 2 + ), f"Got unexpected element {s_e} in the list of VAD segments. 
Expect (start, end) pairs" + checked_pairs.append(tuple(s_e)) + return checked_pairs + elif isinstance(method, str) and method.startswith("silero"): + version = None + if method != "silero": + assert method.startswith("silero:"), f"Got unexpected VAD method {method}" + version = method.split(":")[1] + if not version.startswith("v"): + version = "v" + version + try: + assert float(version[1:]) >= 1 + except: + raise ValueError( + f"Got unexpected silero version {version} (please check https://github.com/snakers4/silero-vad/wiki/Version-history-and-Available-Models)" + ) + if with_version: + return ("silero", version) + else: + return method + elif method == "auditok": + try: + import auditok + except ImportError: + raise ImportError( + "Please install auditok to use the auditok VAD (or use another VAD method)" + ) + else: + try: + method = eval(method) + assert hasattr(method, "__iter__") + except: + raise ValueError(f"Got unexpected VAD method {method}") + return check_vad_method(method, with_version=with_version) + return method