Further refinement of TensorRT-LLM backend based on WhisperS2T

Aleks committed Mar 25, 2024
1 parent b4514bf commit dd954a9
Showing 7 changed files with 561 additions and 56 deletions.

src/wordcab_transcribe/engines/tensorrt_llm/engine_builder/create_trt_model.py
@@ -1,7 +1,10 @@
import hashlib
import os
import subprocess

import requests
from loguru import logger
from tqdm import tqdm

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
@@ -18,8 +21,45 @@
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
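# NOTE: the second-to-last path component of each URL above is the checkpoint's
# SHA256 digest; build_whisper_trt_model() extracts and checks it after download.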

_TOKENIZERS = {
    "tiny.en": (
        "https://huggingface.co/Systran/faster-whisper-tiny.en/raw/main/tokenizer.json"
    ),
    "tiny": (
        "https://huggingface.co/Systran/faster-whisper-tiny/raw/main/tokenizer.json"
    ),
    "small.en": (
        "https://huggingface.co/Systran/faster-whisper-small.en/raw/main/tokenizer.json"
    ),
    "small": (
        "https://huggingface.co/Systran/faster-whisper-small/raw/main/tokenizer.json"
    ),
    "base.en": (
        "https://huggingface.co/Systran/faster-whisper-base.en/raw/main/tokenizer.json"
    ),
    "base": (
        "https://huggingface.co/Systran/faster-whisper-base/raw/main/tokenizer.json"
    ),
    "medium.en": "https://huggingface.co/Systran/faster-whisper-medium.en/raw/main/tokenizer.json",
    "medium": (
        "https://huggingface.co/Systran/faster-whisper-medium/raw/main/tokenizer.json"
    ),
    "large-v1": (
        "https://huggingface.co/Systran/faster-whisper-large-v1/raw/main/tokenizer.json"
    ),
    "large-v2": (
        "https://huggingface.co/Systran/faster-whisper-large-v2/raw/main/tokenizer.json"
    ),
    "large-v3": (
        "https://huggingface.co/Systran/faster-whisper-large-v3/raw/main/tokenizer.json"
    ),
    "large": (
        "https://huggingface.co/Systran/faster-whisper-large-v3/raw/main/tokenizer.json"
    ),
}
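# "large" reuses the large-v3 tokenizer, matching the "large" -> large-v3.pt alias in _MODELS.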


def build_whisper_model(
def build_whisper_trt_model(
    output_dir,
    use_gpt_attention_plugin=True,
    use_gemm_plugin=True,
@@ -44,41 +84,97 @@ def build_whisper_model(
    None
    """
    model_url = _MODELS[model_name]
    model_path = f"assets/{model_name}.pt"

    # Download the model if it doesn't exist
    if not os.path.exists(model_path):
        os.makedirs("assets", exist_ok=True)

        print(f"Downloading model '{model_name}' from {model_url}...")
        response = requests.get(model_url)

        if response.status_code == 200:
            with open(model_path, "wb") as file:
                file.write(response.content)
            print(f"Model '{model_name}' downloaded successfully.")
        else:
            print(
                f"Failed to download model '{model_name}'. Status code:"
                f" {response.status_code}"
            )
            return

    command = ["python3", "build.py", "--output_dir", output_dir]

    if use_gpt_attention_plugin:
        command.append("--use_gpt_attention_plugin")
    if use_gemm_plugin:
        command.append("--use_gemm_plugin")
    if use_bert_attention_plugin:
        command.append("--use_bert_attention_plugin")
    if enable_context_fmha:
        command.append("--enable_context_fmha")
    if use_weight_only:
        command.append("--use_weight_only")

    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error occurred while building the model: {e}")
        raise
    expected_sha256 = model_url.split("/")[-2]
    model_ckpt_path = f"../assets/{model_name}.pt"
    tokenizer_path = f"{output_dir}/tokenizer.json"

    if not os.path.exists(model_ckpt_path):
        os.makedirs("../assets", exist_ok=True)

        logger.info(f"Downloading model '{model_name}' from {model_url}...")

        response = requests.get(model_url, stream=True)
        total_size = int(response.headers.get("Content-Length", 0))

        with open(model_ckpt_path, "wb") as output:
            with tqdm(
                total=total_size,
                ncols=80,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=8192):
                    size = output.write(data)
                    pbar.update(size)

        with open(model_ckpt_path, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
            raise RuntimeError(
                "Model has been downloaded but the SHA256 checksum does not"
                " match. Please retry loading the model."
            )

    if not os.path.exists(output_dir):
        logger.info("Building the model...")
        command = [
            "python3",
            "build.py",
            "--output_dir",
            output_dir,
            "--model_name",
            model_name,
        ]

        if use_gpt_attention_plugin:
            command.append("--use_gpt_attention_plugin")
        if use_gemm_plugin:
            command.append("--use_gemm_plugin")
        if use_bert_attention_plugin:
            command.append("--use_bert_attention_plugin")
        if enable_context_fmha:
            command.append("--enable_context_fmha")
        if use_weight_only:
            command.append("--use_weight_only")

        try:
            subprocess.run(command, check=True)
        except subprocess.CalledProcessError as e:
            logger.error(f"Error occurred while building the model: {e}")
            raise
        logger.info("Model has been built successfully.")

    if not os.path.exists(tokenizer_path):
        logger.info(f"Downloading tokenizer for model '{model_name}'...")
        response = requests.get(_TOKENIZERS[model_name], stream=True)
        total_size = int(response.headers.get("Content-Length", 0))

        with open(tokenizer_path, "wb") as output:
            with tqdm(
                total=total_size,
                ncols=80,
                unit="B",
                unit_scale=True,
                unit_divisor=1024,
            ) as pbar:
                for data in response.iter_content(chunk_size=8192):
                    size = output.write(data)
                    pbar.update(size)
        logger.info("Tokenizer has been downloaded successfully.")

    for filename in os.listdir(output_dir):
        if "encoder" in filename and filename.endswith(".engine"):
            new_filename = "encoder.engine"
            old_path = os.path.join(output_dir, filename)
            new_path = os.path.join(output_dir, new_filename)
            os.rename(old_path, new_path)
            logger.info(f"Renamed '{filename}' to '{new_filename}'")
        elif "decoder" in filename and filename.endswith(".engine"):
            new_filename = "decoder.engine"
            old_path = os.path.join(output_dir, filename)
            new_path = os.path.join(output_dir, new_filename)
            os.rename(old_path, new_path)
            logger.info(f"Renamed '{filename}' to '{new_filename}'")

    return output_dir
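
A minimal usage sketch of the new builder (module path taken from the import in model.py below; the "models/large-v2" directory is illustrative, mirroring model.py's os.path.join("models", model_name)):

from wordcab_transcribe.engines.tensorrt_llm.engine_builder.create_trt_model import (
    build_whisper_trt_model,
)

# Downloads and SHA256-verifies the checkpoint, runs build.py to produce the
# TensorRT engines, fetches the matching tokenizer.json, and renames the engine
# files to encoder.engine and decoder.engine.
engine_dir = build_whisper_trt_model("models/large-v2", model_name="large-v2")
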
24 changes: 13 additions & 11 deletions src/wordcab_transcribe/engines/tensorrt_llm/model.py
@@ -2,10 +2,12 @@

import ctranslate2
import numpy as np
import tokenizers

from wordcab_transcribe.engines.tensorrt_llm.engine_builder.create_trt_model import (
    build_whisper_trt_model,
)
from wordcab_transcribe.engines.tensorrt_llm.hf_utils import download_model
from wordcab_transcribe.engines.tensorrt_llm.tokenizer import Tokenizer
from wordcab_transcribe.engines.tensorrt_llm.tokenizers import Tokenizer
from wordcab_transcribe.engines.tensorrt_llm.trt_model import WhisperTRT
from wordcab_transcribe.engines.tensorrt_llm.whisper_model import WhisperModel

@@ -77,11 +79,11 @@ def exact_div(x, y):


class WhisperModelTRT(WhisperModel):
"""TensorRT implementation of the Whisper model."""
"""TensorRT-LLM implementation of the Whisper model."""

    def __init__(
        self,
        model_name_or_path: str,
        model_name: str,
        asr_options: dict,
        cpu_threads=4,
        num_workers=1,
@@ -91,20 +93,20 @@ def __init__(
        max_text_token_len=15,
        **model_kwargs
    ):
        # ASR Options
        self.asr_options = FAST_ASR_OPTIONS
        self.asr_options.update(asr_options)
        self.model_path = model_name_or_path
        # TODO: build the engine if it doesn't exist
        self.model_name = model_name
        self.model_path = os.path.join("models", self.model_name)

        # Load model
        if not os.path.exists(self.model_path):
            self.model_path = build_whisper_trt_model(
                self.model_path, model_name=self.model_name
            )
        self.model = WhisperTRT(self.model_path)

        # Load tokenizer
        # TODO: Have this downloaded as well
        tokenizer_file = os.path.join(self.model_path, "tokenizer.json")
        tokenizer = Tokenizer(
            tokenizers.Tokenizer.from_file(tokenizer_file), self.model.is_multilingual
            Tokenizer.from_file(tokenizer_file), self.model.is_multilingual
        )

        if self.asr_options["word_timestamps"]:
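
A usage sketch of the updated constructor (the asr_options value shown is a hypothetical override of FAST_ASR_OPTIONS; "large-v2" is illustrative):

from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT

# On first use, a missing "models/large-v2" directory triggers
# build_whisper_trt_model(), which downloads the checkpoint, builds the
# engines, and fetches the tokenizer before WhisperTRT loads them.
model = WhisperModelTRT("large-v2", asr_options={"word_timestamps": False})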
