diff --git a/README.md b/README.md index 00a57062..a0ad2869 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Unlike traditional speech recognition systems that rely on continuous audio stre ## Installation - Install PyAudio and ffmpeg ```bash - bash setup.sh + bash scripts/setup.sh ``` - Install whisper-live from pip @@ -16,61 +16,81 @@ Unlike traditional speech recognition systems that rely on continuous audio stre pip install whisper-live ``` +### Setting up NVIDIA/TensorRT-LLM for TensorRT backend +- Please follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) for setup of [NVIDIA/TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM) and for building Whisper-TensorRT engine. + ## Getting Started -- Run the server +The server supports two backends `faster_whisper` and `tensorrt`. If running `tensorrt` backend follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) + +### Running the Server +- [Faster Whisper](https://github.com/SYSTRAN/faster-whisper) backend +```bash +python3 run_server.py --port 9090 \ + --backend faster_whisper + +# running with custom model +python3 run_server.py --port 9090 \ + --backend faster_whisper + -fw "/path/to/custom/faster/whisper/model" +``` + +- TensorRT backend. Currently, we recommend to only use the docker setup for TensorRT. Follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) which works as expected. Make sure to build your TensorRT Engines before running the server with TensorRT backend. +```bash +# Run English only model +python3 run_server.py -p 9090 \ + -b tensorrt \ + -trt /home/TensorRT-LLM/examples/whisper/whisper_small_en + +# Run Multilingual model +python3 run_server.py -p 9090 \ + -b tensorrt \ + -trt /home/TensorRT-LLM/examples/whisper/whisper_small \ + -m +``` + + +### Running the Client +- To transcribe an audio file: ```python - from whisper_live.server import TranscriptionServer - server = TranscriptionServer() - server.run("0.0.0.0", 9090) +from whisper_live.client import TranscriptionClient +client = TranscriptionClient( + "localhost", + 9090, + is_multilingual=False, + lang="en", + translate=False, + model_size="small" +) + +client("tests/jfk.wav") ``` +This command transcribes the specified audio file (audio.wav) using the Whisper model. It connects to the server running on localhost at port 9090. It can also enable the multilingual feature, allowing transcription in multiple languages. The language option specifies the target language for transcription, in this case, English ("en"). The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language. -- On the client side - - To transcribe an audio file: - ```python - from whisper_live.client import TranscriptionClient - client = TranscriptionClient( - "localhost", - 9090, - is_multilingual=False, - lang="en", - translate=False, - model_size="small" - ) - - client("tests/jfk.wav") - ``` - This command transcribes the specified audio file (audio.wav) using the Whisper model. It connects to the server running on localhost at port 9090. It can also enable the multilingual feature, allowing transcription in multiple languages. The language option specifies the target language for transcription, in this case, English ("en"). 
The translate option should be set to `True` if we want to translate from the source language to English and `False` if we want to transcribe in the source language. - - - To transcribe from microphone: - ```python - from whisper_live.client import TranscriptionClient - client = TranscriptionClient( - "localhost", - 9090, - is_multilingual=True, - lang="hi", - translate=True, - model_size="small" - ) - client() - ``` - This command captures audio from the microphone and sends it to the server for transcription. It uses the multilingual option with `hi` as the selected language, enabling the multilingual feature and specifying the target language and task. We use whisper `small` by default but can be changed to any other option based on the requirements and the hardware running the server. - - - To transcribe from a HLS stream: - ```python - client = TranscriptionClient(host, port, is_multilingual=True, lang="en", translate=False) - client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/bbc_1xtra.isml/bbc_1xtra-audio%3d96000.norewind.m3u8") - ``` - This command streams audio into the server from a HLS stream. It uses the same options as the previous command, enabling the multilingual feature and specifying the target language and task. +- To transcribe from microphone: +```python +from whisper_live.client import TranscriptionClient +client = TranscriptionClient( + "localhost", + 9090, + is_multilingual=True, + lang="hi", + translate=True, + model_size="small" +) +client() +``` +This command captures audio from the microphone and sends it to the server for transcription. It uses the multilingual option with `hi` as the selected language, enabling the multilingual feature and specifying the target language and task. We use whisper `small` by default but can be changed to any other option based on the requirements and the hardware running the server. -## Transcribe audio from browser -- Run the server +- To transcribe from a HLS stream: ```python - from whisper_live.server import TranscriptionServer - server = TranscriptionServer() - server.run("0.0.0.0", 9090) +from whisper_live.client import TranscriptionClient +client = TranscriptionClient(host, port, is_multilingual=True, lang="en", translate=False) +client(hls_url="http://as-hls-ww-live.akamaized.net/pool_904/live/ww/bbc_1xtra/bbc_1xtra.isml/bbc_1xtra-audio%3d96000.norewind.m3u8") ``` -This would start the websocket server on port ```9090```. +This command streams audio into the server from a HLS stream. It uses the same options as the previous command, enabling the multilingual feature and specifying the target language and task. + +## Transcribe audio from browser +- Run the server with your desired backend as shown [here](https://github.com/collabora/WhisperLive?tab=readme-ov-file#running-the-server) ### Chrome Extension - Refer to [Audio-Transcription-Chrome](https://github.com/collabora/whisper-live/tree/main/Audio-Transcription-Chrome#readme) to use Chrome extension. @@ -80,21 +100,24 @@ This would start the websocket server on port ```9090```. ## Whisper Live Server in Docker - GPU -```bash - docker build . -t whisper-live -f docker/Dockerfile.gpu - docker run -it --gpus all -p 9090:9090 whisper-live:latest -``` + - Faster-Whisper + ```bash + docker build . -t whisper-live -f docker/Dockerfile.gpu + docker run -it --gpus all -p 9090:9090 whisper-live:latest + ``` + + - TensorRT. 
Follow [TensorRT_whisper readme](https://github.com/collabora/WhisperLive/blob/main/TensorRT_whisper.md) in order to setup docker and use TensorRT backend. We provide a pre-built docker image which has TensorRT-LLM built and ready to use. - CPU ```bash - docker build . -t whisper-live -f docker/Dockerfile.cpu - docker run -it -p 9090:9090 whisper-live:latest +docker build . -t whisper-live -f docker/Dockerfile.cpu +docker run -it -p 9090:9090 whisper-live:latest ``` **Note**: By default we use "small" model size. To build docker image for a different model size, change the size in server.py and then build the docker image. ## Future Work - [ ] Add translation to other languages on top of transcription. -- [ ] TensorRT backend for Whisper. +- [x] TensorRT backend for Whisper. ## Contact diff --git a/TensorRT_whisper.md b/TensorRT_whisper.md new file mode 100644 index 00000000..2935363f --- /dev/null +++ b/TensorRT_whisper.md @@ -0,0 +1,66 @@ +# Whisper-TensorRT +We have only tested the TensorRT backend in docker so, we recommend docker for a smooth TensorRT backend setup. +**Note**: We use [our fork to setup TensorRT](https://github.com/makaveli10/TensorRT-LLM) + +## Installation +- Install [docker](https://docs.docker.com/engine/install/) +- Install [nvidia-container-toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) + +- Clone this repo. +```bash +git clone https://github.com/collabora/WhisperLive.git +cd WhisperLive +``` + +- Pull the TensorRT-LLM docker image which we prebuilt for WhisperLive TensorRT backend. +```bash +docker pull ghcr.io/collabora/whisperbot-base:latest +``` + +- Next, we run the docker image and mount WhisperLive repo to the containers `/home` directory. +```bash +docker run -it --gpus all --shm-size=8g \ + --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \ + -v /path/to/WhisperLive:/home/WhisperLive \ + ghcr.io/collabora/whisperbot-base:latest +``` + +- Make sure to test the installation. +```bash +# export ENV=${ENV:-/etc/shinit_v2} +# source $ENV +python -c "import torch; import tensorrt; import tensorrt_llm" +``` +**NOTE**: Uncomment and update library paths if imports fail. + +## Whisper TensorRT Engine +- We build `small.en` and `small` multilingual TensorRT engine. The script logs the path of the directory with Whisper TensorRT engine. We need the model_path to run the server. 
+```bash +# convert small.en +bash build_whisper_tensorrt /root/TensorRT-LLM-examples small.en + +# convert small multilingual model +bash build_whisper_tensorrt /root/TensorRT-LLM-examples small +``` + +## Run WhisperLive Server with TensorRT Backend +```bash +cd /home/WhisperLive + +# Install requirements +pip install -r requirements/server.txt + +# Required to create mel spectogram +wget --directory-prefix=assets assets/mel_filters.npz https://raw.githubusercontent.com/openai/whisper/main/whisper/assets/mel_filters.npz + +# Run English only model +python3 run_server.py --port 9090 \ + --backend tensorrt \ + --trt_model_path "path/to/whisper_trt/from/build/step" + +# Run Multilingual model +python3 run_server.py --port 9090 \ + --backend tensorrt \ + --trt_model_path "path/to/whisper_trt/from/build/step" \ + --trt_multilingual +``` diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index d495f6f6..d1958cfe 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -33,7 +33,7 @@ RUN apt install python3-pip -y RUN mkdir /app WORKDIR /app -COPY setup.sh /app +COPY scripts/setup.sh /app COPY requirements/ /app RUN bash setup.sh diff --git a/docker/Dockerfile.gpu b/docker/Dockerfile.gpu index cccd17ab..3e506567 100644 --- a/docker/Dockerfile.gpu +++ b/docker/Dockerfile.gpu @@ -33,7 +33,7 @@ RUN apt install python3-pip -y RUN mkdir /app WORKDIR /app -COPY setup.sh /app +COPY scripts/setup.sh /app COPY requirements/ /app RUN apt update --fix-missing diff --git a/requirements/server.txt b/requirements/server.txt index 62a292e7..77e53788 100644 --- a/requirements/server.txt +++ b/requirements/server.txt @@ -1,7 +1,5 @@ -PyAudio faster-whisper==0.10.0 ---extra-index-url https://download.pytorch.org/whl/cu111 -torch==1.10.1 -torchaudio==0.10.1 +torch websockets -onnxruntime==1.16.0 \ No newline at end of file +onnxruntime==1.16.0 +numba \ No newline at end of file diff --git a/run_server.py b/run_server.py index 0c2d5de3..3feeec6b 100644 --- a/run_server.py +++ b/run_server.py @@ -2,12 +2,37 @@ from whisper_live.server import TranscriptionServer if __name__ == "__main__": - server = TranscriptionServer() parser = argparse.ArgumentParser() - parser.add_argument('--model_path', type=str, default=None, help="Custom Faster Whisper Model") + parser.add_argument('--port', '-p', + type=int, + default=9090, + help="Websocket port to run the server on.") + parser.add_argument('--backend', '-b', + type=str, + default='faster_whisper', + help='Backends from ["tensorrt", "faster_whisper"]') + parser.add_argument('--faster_whisper_custom_model_path', '-fw', + type=str, default=None, + help="Custom Faster Whisper Model") + parser.add_argument('--trt_model_path', '-trt', + type=str, + default=None, + help='Whisper TensorRT model path') + parser.add_argument('--trt_multilingual', '-m', + action="store_true", + help='Boolean only for TensorRT model. 
True if multilingual.') args = parser.parse_args() + + if args.backend == "tensorrt": + if args.trt_model_path is None: + raise ValueError("Please Provide a valid tensorrt model path") + + server = TranscriptionServer() server.run( "0.0.0.0", - 9090, - custom_model_path=args.model_path + port=args.port, + backend=args.backend, + faster_whisper_custom_model_path=args.faster_whisper_custom_model_path, + whisper_tensorrt_path=args.trt_model_path, + trt_multilingual=args.trt_multilingual ) diff --git a/scripts/build_whisper_tensorrt.sh b/scripts/build_whisper_tensorrt.sh new file mode 100644 index 00000000..98248039 --- /dev/null +++ b/scripts/build_whisper_tensorrt.sh @@ -0,0 +1,77 @@ +#!/bin/bash + +download_and_build_model() { + local model_name="$1" + local model_url="" + + case "$model_name" in + "tiny.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt" + ;; + "tiny") + model_url="https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt" + ;; + "base.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt" + ;; + "base") + model_url="https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt" + ;; + "small.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt" + ;; + "small") + model_url="https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt" + ;; + "medium.en") + model_url="https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt" + ;; + "medium") + model_url="https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt" + ;; + "large-v1") + model_url="https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt" + ;; + "large-v2") + model_url="https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt" + ;; + "large-v3" | "large") + model_url="https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt" + ;; + *) + echo "Invalid model name: $model_name" + exit 1 + ;; + esac + + echo "Downloading $model_name..." + # wget --directory-prefix=assets "$model_url" + # echo "Download completed: ${model_name}.pt" + if [ ! -f "assets/${model_name}.pt" ]; then + wget --directory-prefix=assets "$model_url" + echo "Download completed: ${model_name}.pt" + else + echo "${model_name}.pt already exists in assets directory." + fi + + local output_dir="whisper_${model_name//./_}" + echo "$output_dir" + echo "Running build script for $model_name with output directory $output_dir" + python3 build.py --output_dir "$output_dir" --use_gpt_attention_plugin --use_gemm_plugin --use_bert_attention_plugin --model_name "$model_name" + echo "Whisper $model_name TensorRT engine built." 
+ echo "=========================================" + echo "Model is located at: $(pwd)/$output_dir" +} + +if [ "$#" -lt 1 ]; then + echo "Usage: $0 [model-name]" + exit 1 +fi + +tensorrt_examples_dir="$1" +model_name="${2:-small.en}" + +cd $1/whisper +pip install --no-deps -r requirements.txt + +download_and_build_model "$model_name" diff --git a/setup.sh b/scripts/setup.sh similarity index 100% rename from setup.sh rename to scripts/setup.sh diff --git a/whisper_live/client.py b/whisper_live/client.py index 8b08d159..96971c96 100644 --- a/whisper_live/client.py +++ b/whisper_live/client.py @@ -72,7 +72,7 @@ def __init__( lang (str, optional): The selected language for transcription when multilingual is disabled. Default is None. translate (bool, optional): Specifies if the task is translation. Default is False. """ - self.chunk = 1024 + self.chunk = 4096 self.format = pyaudio.paInt16 self.channels = 1 self.rate = 16000 diff --git a/whisper_live/server.py b/whisper_live/server.py index 56e8a64d..be0e9598 100644 --- a/whisper_live/server.py +++ b/whisper_live/server.py @@ -12,10 +12,18 @@ import torch import numpy as np -import time +import queue + +from whisper_live.vad import VoiceActivityDetection +from scipy.io.wavfile import write import functools +from whisper_live.vad import VoiceActivityDetection from whisper_live.transcriber import WhisperModel +try: + from whisper_live.transcriber_tensorrt import WhisperTRTLLM +except Exception as e: + logging.warn("cannot import WhisperTRTLLM") class TranscriptionServer: @@ -37,7 +45,7 @@ class TranscriptionServer: def __init__(self): # voice activity detection model - + self.clients = {} self.websockets = {} self.clients_start_time = {} @@ -61,7 +69,12 @@ def get_wait_time(self): return wait_time / 60 - def recv_audio(self, websocket, custom_model_path=None): + def recv_audio(self, + websocket, + backend="faster_whisper", + faster_whisper_custom_model_path=None, + whisper_tensorrt_path=None, + trt_multilingual=False): """ Receive audio chunks from a client in an infinite loop. @@ -78,10 +91,19 @@ def recv_audio(self, websocket, custom_model_path=None): Args: websocket (WebSocket): The WebSocket connection for the client. - + backend (str): The backend to run the server with. + faster_whisper_custom_model_path (str): path to custom faster whisper model. + whisper_tensorrt_path (str): Required for tensorrt backend. + trt_multilingual(bool): Only used for tensorrt, True if multilingual model. + Raises: Exception: If there is an error during the audio frame processing. 
""" + self.backend = backend + if self.backend == "tensorrt": + self.vad_model = VoiceActivityDetection() + self.vad_threshold = 0.5 + logging.info("New client connected") options = websocket.recv() options = json.loads(options) @@ -98,31 +120,77 @@ def recv_audio(self, websocket, custom_model_path=None): websocket.close() del websocket return - - # validate custom model - if custom_model_path is not None and os.path.exists(custom_model_path): - logging.info(f"Using custom model {custom_model_path}") - options["model"] = custom_model_path - - client = ServeClient( - websocket, - multilingual=options["multilingual"], - language=options["language"], - task=options["task"], - client_uid=options["uid"], - model=options["model"], - initial_prompt=options.get("initial_prompt"), - vad_parameters=options.get("vad_parameters"), - ) + + if self.backend == "tensorrt": + try: + import tensorrt as trt + import tensorrt_llm + self.backend = "tensorrt" + client = ServeClientTensorRT( + websocket, + multilingual=trt_multilingual, + language=options["language"], + task=options["task"], + client_uid=options["uid"], + model=whisper_tensorrt_path + ) + logging.info(f"Running TensorRT backend.") + except Exception as e: + websocket.send( + json.dumps( + { + "uid": self.client_uid, + "status": "ERROR", + "message": f"TensorRT-LLM not supported on Server yet. Reverting to available backend: 'faster_whisper'" + } + ) + ) + self.backend = "faster_whisper" + + if self.backend == "faster_whisper": + # validate custom model + if faster_whisper_custom_model_path is not None and os.path.exists(faster_whisper_custom_model_path): + logging.info(f"Using custom model {faster_whisper_custom_model_path}") + options["model"] = faster_whisper_custom_model_path + client = ServeClientFasterWhisper( + websocket, + multilingual=options["multilingual"], + language=options["language"], + task=options["task"], + client_uid=options["uid"], + model=options["model"], + initial_prompt=options.get("initial_prompt"), + vad_parameters=options.get("vad_parameters") + ) + logging.info(f"Running faster_whisper backend.") self.clients[websocket] = client self.clients_start_time[websocket] = time.time() + no_voice_activity_chunks = 0 while True: try: frame_data = websocket.recv() frame_np = np.frombuffer(frame_data, dtype=np.float32) + # VAD, for faster_whisper VAD model is already integrated + if self.backend == "tensorrt": + try: + speech_prob = self.vad_model(torch.from_numpy(frame_np.copy()), self.RATE).item() + if speech_prob < self.vad_threshold: + no_voice_activity_chunks += 1 + if no_voice_activity_chunks > 3: + if not self.clients[websocket].eos: + self.clients[websocket].set_eos(True) + time.sleep(0.1) # Sleep 100m; wait some voice activity. 
+ continue + no_voice_activity_chunks = 0 + self.clients[websocket].set_eos(False) + + except Exception as e: + logging.error(e) + return + self.clients[websocket].add_frames(frame_np) elapsed_time = time.time() - self.clients_start_time[websocket] @@ -137,15 +205,21 @@ def recv_audio(self, websocket, custom_model_path=None): break except Exception as e: - logging.info(f"[ERROR]: Client with uid '{self.clients[websocket].client_uid}' Disconnected.") - if self.clients[websocket].model_size is not None: - self.clients[websocket].cleanup() + logging.error(e) + self.clients[websocket].cleanup() self.clients.pop(websocket) self.clients_start_time.pop(websocket) del websocket break - def run(self, host, port=9090, custom_model_path=None): + def run(self, + host, + port=9090, + backend="tensorrt", + faster_whisper_custom_model_path=None, + whisper_tensorrt_path=None, + trt_multilingual=False + ): """ Run the transcription server. @@ -156,7 +230,10 @@ def run(self, host, port=9090, custom_model_path=None): with serve( functools.partial( self.recv_audio, - custom_model_path=custom_model_path + backend=backend, + faster_whisper_custom_model_path=faster_whisper_custom_model_path, + whisper_tensorrt_path=whisper_tensorrt_path, + trt_multilingual=trt_multilingual ), host, port @@ -164,7 +241,97 @@ def run(self, host, port=9090, custom_model_path=None): server.serve_forever() -class ServeClient: +class ServeClientBase(object): + RATE = 16000 + SERVER_READY = "SERVER_READY" + DISCONNECT = "DISCONNECT" + + def __init__(self, client_uid, websocket): + self.client_uid = client_uid + self.websocket = websocket + self.data = b"" + self.frames = b"" + self.timestamp_offset = 0.0 + self.frames_np = None + self.frames_offset = 0.0 + self.text = [] + self.current_out = '' + self.prev_out = '' + self.t_start=None + self.exit = False + self.same_output_threshold = 0 + self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds + self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds + self.transcript = [] + self.send_last_n_segments = 10 + + # text formatting + self.wrapper = textwrap.TextWrapper(width=50) + self.pick_previous_segments = 2 + + # threading + self.lock = threading.Lock() + + def add_frames(self, frame_np): + """ + Add audio frames to the ongoing audio stream buffer. + + This method is responsible for maintaining the audio stream buffer, allowing the continuous addition + of audio frames as they are received. It also ensures that the buffer does not exceed a specified size + to prevent excessive memory usage. + + If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds + of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided + audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. + + Args: + frame_np (numpy.ndarray): The audio frame data as a NumPy array. 
+ + """ + self.lock.acquire() + if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: + self.frames_offset += 30.0 + self.frames_np = self.frames_np[int(30*self.RATE):] + if self.frames_np is None: + self.frames_np = frame_np.copy() + else: + self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) + self.lock.release() + + def speech_to_text(self): + raise NotImplementedError("Please implement in child Class.") + + def disconnect(self): + """ + Notify the client of disconnection and send a disconnect message. + + This method sends a disconnect message to the client via the WebSocket connection to notify them + that the transcription service is disconnecting gracefully. + + """ + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.DISCONNECT + } + ) + ) + + def cleanup(self): + """ + Perform cleanup tasks before exiting the transcription service. + + This method performs necessary cleanup tasks, including stopping the transcription thread, marking + the exit flag to indicate the transcription thread should exit gracefully, and destroying resources + associated with the transcription process. + + """ + logging.info("Cleaning up.") + self.exit = True + + +class ServeClientTensorRT(ServeClientBase): """ Attributes: RATE (int): The audio sampling rate (constant) set to 16000. @@ -193,10 +360,191 @@ class ServeClient: pick_previous_segments (int): Number of previous segments to include in the output. websocket: The WebSocket connection for the client. """ - RATE = 16000 - SERVER_READY = "SERVER_READY" - DISCONNECT = "DISCONNECT" + def __init__( + self, + websocket, + task="transcribe", + device=None, + multilingual=False, + language=None, + client_uid=None, + model=None + ): + """ + Initialize a ServeClient instance. + The Whisper model is initialized based on the client's language and device availability. + The transcription thread is started upon initialization. A "SERVER_READY" message is sent + to the client to indicate that the server is ready. + + Args: + websocket (WebSocket): The WebSocket connection for the client. + task (str, optional): The task type, e.g., "transcribe." Defaults to "transcribe". + device (str, optional): The device type for Whisper, "cuda" or "cpu". Defaults to None. + multilingual (bool, optional): Whether the client supports multilingual transcription. Defaults to False. + language (str, optional): The language for transcription. Defaults to None. + client_uid (str, optional): A unique identifier for the client. Defaults to None. + + """ + super().__init__(client_uid, websocket) + self.language = language if multilingual else "en" + self.task = task + self.eos = False + self.transcriber = WhisperTRTLLM( + model, + assets_dir="assets", + device="cuda", + is_multilingual=multilingual, + language=self.language, + task=self.task + ) + + # threading + self.trans_thread = threading.Thread(target=self.speech_to_text) + self.trans_thread.start() + + self.websocket.send( + json.dumps( + { + "uid": self.client_uid, + "message": self.SERVER_READY + } + ) + ) + + def set_eos(self, eos): + self.lock.acquire() + self.eos = eos + self.lock.release() + + def add_frames(self, frame_np): + """ + Add audio frames to the ongoing audio stream buffer. + + This method is responsible for maintaining the audio stream buffer, allowing the continuous addition + of audio frames as they are received. It also ensures that the buffer does not exceed a specified size + to prevent excessive memory usage. 
+ + If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds + of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided + audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. + + Args: + frame_np (numpy.ndarray): The audio frame data as a NumPy array. + + """ + self.lock.acquire() + if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: + self.frames_offset += 30.0 + self.frames_np = self.frames_np[int(30*self.RATE):] + if self.frames_np is None: + self.frames_np = frame_np.copy() + else: + self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) + self.lock.release() + + def speech_to_text(self): + """ + Process an audio stream in an infinite loop, continuously transcribing the speech. + + This method continuously receives audio frames, performs real-time transcription, and sends + transcribed segments to the client via a WebSocket connection. + + If the client's language is not detected, it waits for 30 seconds of audio input to make a language prediction. + It utilizes the Whisper ASR model to transcribe the audio, continuously processing and streaming results. Segments + are sent to the client in real-time, and a history of segments is maintained to provide context.Pauses in speech + (no output from Whisper) are handled by showing the previous output for a set duration. A blank segment is added if + there is no speech for a specified duration to indicate a pause. + + Raises: + Exception: If there is an issue with audio processing or WebSocket communication. + + """ + while True: + if self.exit: + logging.info("Exiting speech to text thread") + break + + if self.frames_np is None: + time.sleep(0.02) # wait for any audio to arrive + continue + + # clip audio if the current chunk exceeds 30 seconds, this basically implies that + # no valid segment for the last 30 seconds from whisper + if self.frames_np[int((self.timestamp_offset - self.frames_offset)*self.RATE):].shape[0] > 25 * self.RATE: + duration = self.frames_np.shape[0] / self.RATE + self.timestamp_offset = self.frames_offset + duration - 5 + + samples_take = max(0, (self.timestamp_offset - self.frames_offset)*self.RATE) + input_bytes = self.frames_np[int(samples_take):].copy() + duration = input_bytes.shape[0] / self.RATE + if duration<0.4: + continue + + try: + input_sample = input_bytes.copy() + logging.info(f"[WhisperTensorRT:] Processing audio with duration: {duration}") + mel, duration = self.transcriber.log_mel_spectrogram(input_sample) + last_segment = self.transcriber.transcribe(mel) + segments = [] + if len(last_segment): + if len(self.transcript) < self.send_last_n_segments: + segments = self.transcript[:].copy() + else: + segments = self.transcript[-self.send_last_n_segments:].copy() + if last_segment is not None: + segments.append({"text": last_segment}) + try: + self.websocket.send( + json.dumps({ + "uid": self.client_uid, + "segments": segments, + }) + ) + + if self.eos: + if not len(self.transcript): + self.transcript.append({"text": last_segment + " "}) + elif self.transcript[-1]["text"].strip() != last_segment: + self.transcript.append({"text": last_segment + " "}) + self.timestamp_offset += duration + + + except Exception as e: + logging.error(f"[ERROR]: {e}") + + except Exception as e: + logging.error(f"[ERROR]: {e}") + +class ServeClientFasterWhisper(ServeClientBase): + """ + Attributes: + RATE (int): The audio sampling rate (constant) 
set to 16000. + SERVER_READY (str): A constant message indicating that the server is ready. + DISCONNECT (str): A constant message indicating that the client should disconnect. + client_uid (str): A unique identifier for the client. + data (bytes): Accumulated audio data. + frames (bytes): Accumulated audio frames. + language (str): The language for transcription. + task (str): The task type, e.g., "transcribe." + transcriber (WhisperModel): The Whisper model for speech-to-text. + timestamp_offset (float): The offset in audio timestamps. + frames_np (numpy.ndarray): NumPy array to store audio frames. + frames_offset (float): The offset in audio frames. + text (list): List of transcribed text segments. + current_out (str): The current incomplete transcription. + prev_out (str): The previous incomplete transcription. + t_start (float): Timestamp for the start of transcription. + exit (bool): A flag to exit the transcription thread. + same_output_threshold (int): Threshold for consecutive same output segments. + show_prev_out_thresh (int): Threshold for showing previous output segments. + add_pause_thresh (int): Threshold for adding a pause (blank) segment. + transcript (list): List of transcribed segments. + send_last_n_segments (int): Number of last segments to send to the client. + wrapper (textwrap.TextWrapper): Text wrapper for formatting text. + pick_previous_segments (int): Number of previous segments to include in the output. + websocket: The WebSocket connection for the client. + """ def __init__( self, websocket, @@ -224,9 +572,7 @@ def __init__( client_uid (str, optional): A unique identifier for the client. Defaults to None. """ - self.client_uid = client_uid - self.data = b"" - self.frames = b"" + super().__init__(client_uid, websocket) self.model_sizes = [ "tiny", "tiny.en", "base", "base.en", "small", "small.en", "medium", "medium.en", "large-v2", "large-v3", @@ -237,42 +583,22 @@ def __init__( self.model_size_or_path = self.get_model_size(model) else: self.model_size_or_path = model - self.language = language if self.multilingual else "en" self.task = task - self.websocket = websocket self.initial_prompt = initial_prompt self.vad_parameters = vad_parameters or {"threshold": 0.5} device = "cuda" if torch.cuda.is_available() else "cpu" - if self.model_size_or_path is None: + if self.model_size_or_path == None: return self.transcriber = WhisperModel( - self.model_size_or_path, + self.model_size_or_path, device=device, compute_type="int8" if device=="cpu" else "float16", local_files_only=False, ) - - self.timestamp_offset = 0.0 - self.frames_np = None - self.frames_offset = 0.0 - self.text = [] - self.current_out = '' - self.prev_out = '' - self.t_start=None - self.exit = False - self.same_output_threshold = 0 - self.show_prev_out_thresh = 5 # if pause(no output from whisper) show previous output for 5 seconds - self.add_pause_thresh = 3 # add a blank to segment list as a pause(no speech) for 3 seconds - self.transcript = [] - self.send_last_n_segments = 10 - - # text formatting - self.wrapper = textwrap.TextWrapper(width=50) - self.pick_previous_segments = 2 # threading self.trans_thread = threading.Thread(target=self.speech_to_text) @@ -312,30 +638,6 @@ def get_model_size(self, model_size): return model_size - def add_frames(self, frame_np): - """ - Add audio frames to the ongoing audio stream buffer. - - This method is responsible for maintaining the audio stream buffer, allowing the continuous addition - of audio frames as they are received. 
It also ensures that the buffer does not exceed a specified size - to prevent excessive memory usage. - - If the buffer size exceeds a threshold (45 seconds of audio data), it discards the oldest 30 seconds - of audio data to maintain a reasonable buffer size. If the buffer is empty, it initializes it with the provided - audio frame. The audio stream buffer is used for real-time processing of audio data for transcription. - - Args: - frame_np (numpy.ndarray): The audio frame data as a NumPy array. - - """ - if self.frames_np is not None and self.frames_np.shape[0] > 45*self.RATE: - self.frames_offset += 30.0 - self.frames_np = self.frames_np[int(30*self.RATE):] - if self.frames_np is None: - self.frames_np = frame_np.copy() - else: - self.frames_np = np.concatenate((self.frames_np, frame_np), axis=0) - def speech_to_text(self): """ Process an audio stream in an infinite loop, continuously transcribing the speech. @@ -509,32 +811,3 @@ def update_segments(self, segments, duration): self.timestamp_offset += offset return last_segment - - def disconnect(self): - """ - Notify the client of disconnection and send a disconnect message. - - This method sends a disconnect message to the client via the WebSocket connection to notify them - that the transcription service is disconnecting gracefully. - - """ - self.websocket.send( - json.dumps( - { - "uid": self.client_uid, - "message": self.DISCONNECT - } - ) - ) - - def cleanup(self): - """ - Perform cleanup tasks before exiting the transcription service. - - This method performs necessary cleanup tasks, including stopping the transcription thread, marking - the exit flag to indicate the transcription thread should exit gracefully, and destroying resources - associated with the transcription process. - - """ - logging.info("Cleaning up.") - self.exit = True diff --git a/whisper_live/tensorrt_utils.py b/whisper_live/tensorrt_utils.py new file mode 100644 index 00000000..7b21010d --- /dev/null +++ b/whisper_live/tensorrt_utils.py @@ -0,0 +1,365 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import os +from collections import defaultdict +from functools import lru_cache +from pathlib import Path +from subprocess import CalledProcessError, run +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import kaldialign +import numpy as np +import soundfile +import torch +import torch.nn.functional as F + +Pathlike = Union[str, Path] + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +def load_audio(file: str, sr: int = SAMPLE_RATE): + """ + Open an audio file and read as mono waveform, resampling as necessary + + Parameters + ---------- + file: str + The audio file to open + + sr: int + The sample rate to resample the audio if necessary + + Returns + ------- + A NumPy array containing the audio waveform, in float32 dtype. + """ + + # This launches a subprocess to decode audio while down-mixing + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # fmt: off + cmd = [ + "ffmpeg", "-nostdin", "-threads", "0", "-i", file, "-f", "s16le", "-ac", + "1", "-acodec", "pcm_s16le", "-ar", + str(sr), "-" + ] + # fmt: on + try: + out = run(cmd, capture_output=True, check=True).stdout + except CalledProcessError as e: + raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e + + return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0 + + +def load_audio_wav_format(wav_path): + # make sure audio in .wav format + assert wav_path.endswith( + '.wav'), f"Only support .wav format, but got {wav_path}" + waveform, sample_rate = soundfile.read(wav_path) + assert sample_rate == 16000, f"Only support 16k sample rate, but got {sample_rate}" + return waveform, sample_rate + + +def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1): + """ + Pad or trim the audio array to N_SAMPLES, as expected by the encoder. + """ + if torch.is_tensor(array): + if array.shape[axis] > length: + array = array.index_select(dim=axis, + index=torch.arange(length, + device=array.device)) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = F.pad(array, + [pad for sizes in pad_widths[::-1] for pad in sizes]) + else: + if array.shape[axis] > length: + array = array.take(indices=range(length), axis=axis) + + if array.shape[axis] < length: + pad_widths = [(0, 0)] * array.ndim + pad_widths[axis] = (0, length - array.shape[axis]) + array = np.pad(array, pad_widths) + + return array + + +@lru_cache(maxsize=None) +def mel_filters(device, + n_mels: int, + mel_filters_dir: str = None) -> torch.Tensor: + """ + load the mel filterbank matrix for projecting STFT into a Mel spectrogram. 
+ Allows decoupling librosa dependency; saved using: + + np.savez_compressed( + "mel_filters.npz", + mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80), + ) + """ + assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}" + if mel_filters_dir is None: + mel_filters_path = os.path.join(os.path.dirname(__file__), "assets", + "mel_filters.npz") + else: + mel_filters_path = os.path.join(mel_filters_dir, "mel_filters.npz") + with np.load(mel_filters_path) as f: + return torch.from_numpy(f[f"mel_{n_mels}"]).to(device) + + +def log_mel_spectrogram( + audio: Union[str, np.ndarray, torch.Tensor], + n_mels: int, + padding: int = 0, + device: Optional[Union[str, torch.device]] = None, + return_duration: bool = False, + mel_filters_dir: str = None, +): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, + np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if device is not None: + audio = audio.to(device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, + N_FFT, + HOP_LENGTH, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + filters = mel_filters(audio.device, n_mels, mel_filters_dir) + mel_spec = filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + +def store_transcripts(filename: Pathlike, texts: Iterable[Tuple[str, str, + str]]) -> None: + """Save predicted results and reference transcripts to a file. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + Args: + filename: + File to save the results to. + texts: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + Returns: + Return None. + """ + with open(filename, "w") as f: + for cut_id, ref, hyp in texts: + print(f"{cut_id}:\tref={ref}", file=f) + print(f"{cut_id}:\thyp={hyp}", file=f) + + +def write_error_stats( + f: TextIO, + test_set_name: str, + results: List[Tuple[str, str]], + enable_log: bool = True, +) -> float: + """Write statistics based on predicted results and reference transcripts. + https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py + It will write the following to the given file: + + - WER + - number of insertions, deletions, substitutions, corrects and total + reference words. 
For example:: + + Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606 + reference words (2337 correct) + + - The difference between the reference transcript and predicted result. + An instance is given below:: + + THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES + + The above example shows that the reference word is `EDISON`, + but it is predicted to `ADDISON` (a substitution error). + + Another example is:: + + FOR THE FIRST DAY (SIR->*) I THINK + + The reference word `SIR` is missing in the predicted + results (a deletion error). + results: + An iterable of tuples. The first element is the cur_id, the second is + the reference transcript and the third element is the predicted result. + enable_log: + If True, also print detailed WER to the console. + Otherwise, it is written only to the given file. + Returns: + Return None. + """ + subs: Dict[Tuple[str, str], int] = defaultdict(int) + ins: Dict[str, int] = defaultdict(int) + dels: Dict[str, int] = defaultdict(int) + + # `words` stores counts per word, as follows: + # corr, ref_sub, hyp_sub, ins, dels + words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0]) + num_corr = 0 + ERR = "*" + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + for ref_word, hyp_word in ali: + if ref_word == ERR: + ins[hyp_word] += 1 + words[hyp_word][3] += 1 + elif hyp_word == ERR: + dels[ref_word] += 1 + words[ref_word][4] += 1 + elif hyp_word != ref_word: + subs[(ref_word, hyp_word)] += 1 + words[ref_word][1] += 1 + words[hyp_word][2] += 1 + else: + words[ref_word][0] += 1 + num_corr += 1 + ref_len = sum([len(r) for _, r, _ in results]) + sub_errs = sum(subs.values()) + ins_errs = sum(ins.values()) + del_errs = sum(dels.values()) + tot_errs = sub_errs + ins_errs + del_errs + tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + + if enable_log: + logging.info(f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]") + + print(f"%WER = {tot_err_rate}", file=f) + print( + f"Errors: {ins_errs} insertions, {del_errs} deletions, " + f"{sub_errs} substitutions, over {ref_len} reference " + f"words ({num_corr} correct)", + file=f, + ) + print( + "Search below for sections starting with PER-UTT DETAILS:, " + "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:", + file=f, + ) + + print("", file=f) + print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f) + for cut_id, ref, hyp in results: + ali = kaldialign.align(ref, hyp, ERR) + combine_successive_errors = True + if combine_successive_errors: + ali = [[[x], [y]] for x, y in ali] + for i in range(len(ali) - 1): + if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]: + ali[i + 1][0] = ali[i][0] + ali[i + 1][0] + ali[i + 1][1] = ali[i][1] + ali[i + 1][1] + ali[i] = [[], []] + ali = [[ + list(filter(lambda a: a != ERR, x)), + list(filter(lambda a: a != ERR, y)), + ] for x, y in ali] + ali = list(filter(lambda x: x != [[], []], ali)) + ali = [[ + ERR if x == [] else " ".join(x), + ERR if y == [] else " ".join(y), + ] for x, y in ali] + + print( + f"{cut_id}:\t" + " ".join((ref_word if ref_word == hyp_word else + f"({ref_word}->{hyp_word})" + for ref_word, hyp_word in ali)), + file=f, + ) + + print("", file=f) + print("SUBSTITUTIONS: count ref -> hyp", file=f) + + for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], + reverse=True): + print(f"{count} {ref} -> {hyp}", file=f) + + print("", file=f) + print("DELETIONS: count ref", file=f) + for count, ref in 
sorted([(v, k) for k, v in dels.items()], reverse=True): + print(f"{count} {ref}", file=f) + + print("", file=f) + print("INSERTIONS: count hyp", file=f) + for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True): + print(f"{count} {hyp}", file=f) + + print("", file=f) + print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", + file=f) + for _, word, counts in sorted([(sum(v[1:]), k, v) + for k, v in words.items()], + reverse=True): + (corr, ref_sub, hyp_sub, ins, dels) = counts + tot_errs = ref_sub + hyp_sub + ins + dels + ref_count = corr + ref_sub + dels + hyp_count = corr + hyp_sub + ins + + print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) + return float(tot_err_rate) \ No newline at end of file diff --git a/whisper_live/transcriber_tensorrt.py b/whisper_live/transcriber_tensorrt.py new file mode 100644 index 00000000..8634a8f5 --- /dev/null +++ b/whisper_live/transcriber_tensorrt.py @@ -0,0 +1,340 @@ +import argparse +import json +import re +import time +from collections import OrderedDict +from pathlib import Path +from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union + +import torch +import numpy as np +from whisper.tokenizer import get_tokenizer +from whisper_live.tensorrt_utils import (mel_filters, store_transcripts, + write_error_stats, load_audio_wav_format, + pad_or_trim) + +import tensorrt_llm +import tensorrt_llm.logger as logger +from tensorrt_llm._utils import (str_dtype_to_torch, str_dtype_to_trt, + trt_dtype_to_torch) +from tensorrt_llm.runtime import ModelConfig, SamplingConfig +from tensorrt_llm.runtime.session import Session, TensorInfo + + +SAMPLE_RATE = 16000 +N_FFT = 400 +HOP_LENGTH = 160 +CHUNK_LENGTH = 30 +N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk + + +class WhisperEncoding: + + def __init__(self, engine_dir): + self.session = self.get_session(engine_dir) + + def get_session(self, engine_dir): + config_path = engine_dir / 'encoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + + use_gpt_attention_plugin = config['plugin_config'][ + 'gpt_attention_plugin'] + dtype = config['builder_config']['precision'] + n_mels = config['builder_config']['n_mels'] + num_languages = config['builder_config']['num_languages'] + + self.dtype = dtype + self.n_mels = n_mels + self.num_languages = num_languages + + serialize_path = engine_dir / f'whisper_encoder_{self.dtype}_tp1_rank0.engine' + + with open(serialize_path, 'rb') as f: + session = Session.from_serialized_engine(f.read()) + + return session + + def get_audio_features(self, mel): + inputs = OrderedDict() + output_list = [] + + inputs.update({'x': mel}) + output_list.append( + TensorInfo('x', str_dtype_to_trt(self.dtype), mel.shape)) + + output_info = (self.session).infer_shapes(output_list) + + logger.debug(f'output info {output_info}') + outputs = { + t.name: torch.empty(tuple(t.shape), + dtype=trt_dtype_to_torch(t.dtype), + device='cuda') + for t in output_info + } + stream = torch.cuda.current_stream() + ok = self.session.run(inputs=inputs, + outputs=outputs, + stream=stream.cuda_stream) + assert ok, 'Engine execution failed' + stream.synchronize() + audio_features = outputs['output'] + return audio_features + + +class WhisperDecoding: + + def __init__(self, engine_dir, runtime_mapping, debug_mode=False): + + self.decoder_config = self.get_config(engine_dir) + self.decoder_generation_session = self.get_session( + engine_dir, runtime_mapping, debug_mode) + + def get_config(self, engine_dir): + 
config_path = engine_dir / 'decoder_config.json' + with open(config_path, 'r') as f: + config = json.load(f) + decoder_config = OrderedDict() + decoder_config.update(config['plugin_config']) + decoder_config.update(config['builder_config']) + return decoder_config + + def get_session(self, engine_dir, runtime_mapping, debug_mode=False): + dtype = self.decoder_config['precision'] + serialize_path = engine_dir / f'whisper_decoder_{dtype}_tp1_rank0.engine' + with open(serialize_path, "rb") as f: + decoder_engine_buffer = f.read() + + decoder_model_config = ModelConfig( + num_heads=self.decoder_config['num_heads'], + num_kv_heads=self.decoder_config['num_heads'], + hidden_size=self.decoder_config['hidden_size'], + vocab_size=self.decoder_config['vocab_size'], + num_layers=self.decoder_config['num_layers'], + gpt_attention_plugin=self.decoder_config['gpt_attention_plugin'], + remove_input_padding=self.decoder_config['remove_input_padding'], + cross_attention=self.decoder_config['cross_attention'], + has_position_embedding=self. + decoder_config['has_position_embedding'], + has_token_type_embedding=self. + decoder_config['has_token_type_embedding'], + ) + decoder_generation_session = tensorrt_llm.runtime.GenerationSession( + decoder_model_config, + decoder_engine_buffer, + runtime_mapping, + debug_mode=debug_mode) + + return decoder_generation_session + + def generate(self, + decoder_input_ids, + encoder_outputs, + eot_id, + max_new_tokens=40, + num_beams=1): + encoder_input_lengths = torch.tensor( + [encoder_outputs.shape[1] for x in range(encoder_outputs.shape[0])], + dtype=torch.int32, + device='cuda') + + decoder_input_lengths = torch.tensor([ + decoder_input_ids.shape[-1] + for _ in range(decoder_input_ids.shape[0]) + ], + dtype=torch.int32, + device='cuda') + decoder_max_input_length = torch.max(decoder_input_lengths).item() + + # generation config + sampling_config = SamplingConfig(end_id=eot_id, + pad_id=eot_id, + num_beams=num_beams) + self.decoder_generation_session.setup( + decoder_input_lengths.size(0), + decoder_max_input_length, + max_new_tokens, + beam_width=num_beams, + encoder_max_input_length=encoder_outputs.shape[1]) + + torch.cuda.synchronize() + + decoder_input_ids = decoder_input_ids.type(torch.int32).cuda() + output_ids = self.decoder_generation_session.decode( + decoder_input_ids, + decoder_input_lengths, + sampling_config, + encoder_output=encoder_outputs, + encoder_input_lengths=encoder_input_lengths, + ) + torch.cuda.synchronize() + + # get the list of int from output_ids tensor + output_ids = output_ids.cpu().numpy().tolist() + return output_ids + + +class WhisperTRTLLM(object): + + def __init__( + self, + engine_dir, + debug_mode=False, + assets_dir=None, + device=None, + is_multilingual=False, + language="en", + task="transcribe" + ): + world_size = 1 + runtime_rank = tensorrt_llm.mpi_rank() + runtime_mapping = tensorrt_llm.Mapping(world_size, runtime_rank) + torch.cuda.set_device(runtime_rank % runtime_mapping.gpus_per_node) + engine_dir = Path(engine_dir) + + self.encoder = WhisperEncoding(engine_dir) + self.decoder = WhisperDecoding(engine_dir, + runtime_mapping, + debug_mode=False) + self.n_mels = self.encoder.n_mels + # self.tokenizer = get_tokenizer(num_languages=self.encoder.num_languages, + # tokenizer_dir=assets_dir) + self.device = device + self.tokenizer = get_tokenizer( + is_multilingual, + num_languages=self.encoder.num_languages, + language=language, + task=task, + ) + self.filters = mel_filters(self.device, self.encoder.n_mels, assets_dir) + + def 
log_mel_spectrogram( + self, + audio: Union[str, np.ndarray, torch.Tensor], + padding: int = 0, + return_duration = True + ): + """ + Compute the log-Mel spectrogram of + + Parameters + ---------- + audio: Union[str, np.ndarray, torch.Tensor], shape = (*) + The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz + + n_mels: int + The number of Mel-frequency filters, only 80 and 128 are supported + + padding: int + Number of zero samples to pad to the right + + device: Optional[Union[str, torch.device]] + If given, the audio tensor is moved to this device before STFT + + Returns + ------- + torch.Tensor, shape = (80 or 128, n_frames) + A Tensor that contains the Mel spectrogram + """ + if not torch.is_tensor(audio): + if isinstance(audio, str): + if audio.endswith('.wav'): + audio, _ = load_audio_wav_format(audio) + else: + audio = load_audio(audio) + assert isinstance(audio, + np.ndarray), f"Unsupported audio type: {type(audio)}" + duration = audio.shape[-1] / SAMPLE_RATE + audio = pad_or_trim(audio, N_SAMPLES) + audio = audio.astype(np.float32) + audio = torch.from_numpy(audio) + + if self.device is not None: + audio = audio.to(self.device) + if padding > 0: + audio = F.pad(audio, (0, padding)) + window = torch.hann_window(N_FFT).to(audio.device) + stft = torch.stft(audio, + N_FFT, + HOP_LENGTH, + window=window, + return_complex=True) + magnitudes = stft[..., :-1].abs()**2 + + + mel_spec = self.filters @ magnitudes + + log_spec = torch.clamp(mel_spec, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + log_spec = (log_spec + 4.0) / 4.0 + if return_duration: + return log_spec, duration + else: + return log_spec + + + def process_batch( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + num_beams=1): + prompt_id = self.tokenizer.encode( + text_prefix, allowed_special=set(self.tokenizer.special_tokens.keys())) + + prompt_id = torch.tensor(prompt_id) + batch_size = mel.shape[0] + decoder_input_ids = prompt_id.repeat(batch_size, 1) + + encoder_output = self.encoder.get_audio_features(mel) + output_ids = self.decoder.generate(decoder_input_ids, + encoder_output, + self.tokenizer.eot, + max_new_tokens=96, + num_beams=num_beams) + texts = [] + for i in range(len(output_ids)): + text = self.tokenizer.decode(output_ids[i][0]).strip() + texts.append(text) + return texts + + def transcribe( + self, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + ): + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + predictions = self.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + return prediction.strip() + + +def decode_wav_file( + model, + mel, + text_prefix="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>", + dtype='float16', + batch_size=1, + num_beams=1, + normalizer=None, + mel_filters_dir=None): + + mel = mel.type(str_dtype_to_torch(dtype)) + mel = mel.unsqueeze(0) + # repeat the mel spectrogram to match the batch size + mel = mel.repeat(batch_size, 1, 1) + predictions = model.process_batch(mel, text_prefix, num_beams) + prediction = predictions[0] + + # remove all special tokens in the prediction + prediction = re.sub(r'<\|.*?\|>', '', prediction) + if normalizer: + prediction = normalizer(prediction) + + return prediction.strip() \ No newline at end of file diff 
--git a/whisper_live/vad.py b/whisper_live/vad.py new file mode 100644 index 00000000..53170f98 --- /dev/null +++ b/whisper_live/vad.py @@ -0,0 +1,118 @@ +# original: https://github.com/snakers4/silero-vad/blob/master/utils_vad.py + +import os +import subprocess +import torch +import numpy as np +import onnxruntime + + +class VoiceActivityDetection(): + + def __init__(self, force_onnx_cpu=True): + print("downloading ONNX model...") + path = self.download() + print("loading session") + + opts = onnxruntime.SessionOptions() + opts.log_severity_level = 3 + + opts.inter_op_num_threads = 1 + opts.intra_op_num_threads = 1 + + print("loading onnx model") + if force_onnx_cpu and 'CPUExecutionProvider' in onnxruntime.get_available_providers(): + self.session = onnxruntime.InferenceSession(path, providers=['CPUExecutionProvider'], sess_options=opts) + else: + self.session = onnxruntime.InferenceSession(path, providers=['CUDAExecutionProvider'], sess_options=opts) + + print("reset states") + self.reset_states() + self.sample_rates = [8000, 16000] + + def _validate_input(self, x, sr: int): + if x.dim() == 1: + x = x.unsqueeze(0) + if x.dim() > 2: + raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}") + + if sr != 16000 and (sr % 16000 == 0): + step = sr // 16000 + x = x[:,::step] + sr = 16000 + + if sr not in self.sample_rates: + raise ValueError(f"Supported sampling rates: {self.sample_rates} (or multiply of 16000)") + + if sr / x.shape[1] > 31.25: + raise ValueError("Input audio chunk is too short") + + return x, sr + + def reset_states(self, batch_size=1): + self._h = np.zeros((2, batch_size, 64)).astype('float32') + self._c = np.zeros((2, batch_size, 64)).astype('float32') + self._last_sr = 0 + self._last_batch_size = 0 + + def __call__(self, x, sr: int): + + x, sr = self._validate_input(x, sr) + batch_size = x.shape[0] + + if not self._last_batch_size: + self.reset_states(batch_size) + if (self._last_sr) and (self._last_sr != sr): + self.reset_states(batch_size) + if (self._last_batch_size) and (self._last_batch_size != batch_size): + self.reset_states(batch_size) + + if sr in [8000, 16000]: + ort_inputs = {'input': x.numpy(), 'h': self._h, 'c': self._c, 'sr': np.array(sr, dtype='int64')} + ort_outs = self.session.run(None, ort_inputs) + out, self._h, self._c = ort_outs + else: + raise ValueError() + + self._last_sr = sr + self._last_batch_size = batch_size + + out = torch.tensor(out) + return out + + def audio_forward(self, x, sr: int, num_samples: int = 512): + outs = [] + x, sr = self._validate_input(x, sr) + + if x.shape[1] % num_samples: + pad_num = num_samples - (x.shape[1] % num_samples) + x = torch.nn.functional.pad(x, (0, pad_num), 'constant', value=0.0) + + self.reset_states(x.shape[0]) + for i in range(0, x.shape[1], num_samples): + wavs_batch = x[:, i:i+num_samples] + out_chunk = self.__call__(wavs_batch, sr) + outs.append(out_chunk) + + stacked = torch.cat(outs, dim=1) + return stacked.cpu() + + @staticmethod + def download(model_url="https://github.com/snakers4/silero-vad/raw/master/files/silero_vad.onnx"): + target_dir = os.path.expanduser("~/.cache/whisper-live/") + + # Ensure the target directory exists + os.makedirs(target_dir, exist_ok=True) + + # Define the target file path + model_filename = os.path.join(target_dir, "silero_vad.onnx") + + # Check if the model file already exists + if not os.path.exists(model_filename): + # If it doesn't exist, download the model using wget + print("Downloading VAD ONNX model...") + try: + subprocess.run(["wget", "-O", 
model_filename, model_url], check=True) + except subprocess.CalledProcessError: + print("Failed to download the model using wget.") + return model_filename \ No newline at end of file
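
For reference, a minimal sketch of exercising the new `VoiceActivityDetection` helper on its own, mirroring how `TranscriptionServer.recv_audio` gates the TensorRT backend on speech probability. The 16 kHz rate and 0.5 threshold match the server defaults; the synthetic frame below is only a stand-in for a streamed audio chunk.

```python
import numpy as np
import torch

from whisper_live.vad import VoiceActivityDetection

RATE = 16000           # sample rate the VAD model expects
VAD_THRESHOLD = 0.5    # same default the server uses for the tensorrt backend

# First call downloads silero_vad.onnx to ~/.cache/whisper-live/
vad_model = VoiceActivityDetection()

# One second of synthetic float32 mono audio standing in for a websocket frame
frame_np = (np.random.randn(RATE) * 0.01).astype(np.float32)

speech_prob = vad_model(torch.from_numpy(frame_np.copy()), RATE).item()
if speech_prob < VAD_THRESHOLD:
    print(f"no speech ({speech_prob:.2f}); the server would skip this frame")
else:
    print(f"speech ({speech_prob:.2f}); the server would buffer this frame for transcription")
```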