Skip to content

Commit

Permalink
Fixing align model for n_mels=128 models
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleks committed Mar 30, 2024
1 parent f84e782 commit 6a0bdc4
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 20 deletions.
9 changes: 7 additions & 2 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ DEBUG=True
# The whisper_model parameter is used to control the model used for ASR.
#
# Cloud models:
# The available models are: tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, or large-v2
# The available models are: tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large, large-v1,
# large-v2, large-v3, distil-large-v2, or distil-large-v3.
# You can try different model size, but you should see a trade-off between performance and speed.
#
# Local models:
Expand All @@ -32,8 +33,12 @@ DEBUG=True
# e.g. WHISPER_MODEL="/app/models/custom"
# docker cmd: -v /path/to/custom/model:/app/models/custom
WHISPER_MODEL="distil-large-v2"
# You can specify one of two engines, "faster-whisper" or "tensorrt-llm".
# You can specify one of two engines, "faster-whisper" or "tensorrt-llm". At the moment, "faster-whisper" is more
# stable, adjustable, and accurate, while "tensorrt-llm" is faster but less accurate and adjustable.
WHISPER_ENGINE="tensorrt-llm"
# The align model is used for aligning timestamps under the "tensorrt-llm" engine. The available options are:
# "tiny", "small", "base", or "medium".
ALIGN_MODEL="tiny"
# The compute_type parameter is used to control the precision of the model. You can choose between:
# "int8", "int8_float16", "int8_bfloat16", "int16", "float_16", "bfloat16", "float32".
# The default value is "float16".
Expand Down
41 changes: 41 additions & 0 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class Settings:
# Whisper
whisper_model: str
whisper_engine: str
align_model: str
compute_type: str
extra_languages: Union[List[str], None]
extra_languages_model_paths: Union[Dict[str, str], None]
Expand Down Expand Up @@ -86,6 +87,34 @@ def project_name_must_not_be_none(cls, value: str): # noqa: B902, N805

return value

@field_validator("whisper_model")
def whisper_model_compatibility_check(cls, value: str): # noqa: B902, N805
"""Check that the whisper engine is compatible."""
if value.lower() not in [
"tiny",
"tiny.en",
"base",
"base.en",
"small",
"small.en",
"medium",
"medium.en",
"large",
"large-v1",
"large-v2",
"large-v3",
"distil-large-v2",
"distil-large-v3",
]:
raise ValueError(
"The whisper models must be one of `tiny`, `tiny.en`, `base`,"
" `base.en`, `small`, `small.en`, `medium`, `medium.en`, `large`,"
" `large-v1`, `large-v2`, `large-v3`, `distil-large-v2`, or"
" `distil-large-v3`."
)

return value

@field_validator("whisper_engine")
def whisper_engine_compatibility_check(cls, value: str): # noqa: B902, N805
"""Check that the whisper engine is compatible."""
Expand All @@ -96,6 +125,17 @@ def whisper_engine_compatibility_check(cls, value: str): # noqa: B902, N805

return value

@field_validator("align_model")
def align_model_compatibility_check(cls, value: str): # noqa: B902, N805
"""Check that the whisper engine is compatible."""
if value.lower() not in ["tiny", "small", "base", "medium"]:
raise ValueError(
"The whisper engine must be one of `tiny`, `small`, `base`, or"
" `medium`."
)

return value

@field_validator("version")
def version_must_not_be_none(cls, value: str): # noqa: B902, N805
"""Check that the version is not None."""
Expand Down Expand Up @@ -264,6 +304,7 @@ def __post_init__(self):
# Transcription
whisper_model=getenv("WHISPER_MODEL", "distil-large-v2"),
whisper_engine=getenv("WHISPER_ENGINE", "tensorrt-llm"),
align_model=getenv("ALIGN_MODEL", "tiny"),
compute_type=getenv("COMPUTE_TYPE", "float16"),
extra_languages=extra_languages,
extra_languages_model_paths=extra_languages_model_paths,
Expand Down
39 changes: 26 additions & 13 deletions src/wordcab_transcribe/engines/tensorrt_llm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def exact_div(x, y):
"sampling_temperature": 1.0,
"return_scores": True,
"return_no_speech_prob": True,
"word_aligner_model": "tiny",
"word_align_model": "tiny",
}

BEST_ASR_CONFIG = {
Expand All @@ -75,7 +75,7 @@ def exact_div(x, y):
"sampling_temperature": 1.0,
"return_scores": True,
"return_no_speech_prob": True,
"word_aligner_model": "tiny",
"word_align_model": "tiny",
}


Expand All @@ -98,6 +98,8 @@ def __init__(
self.asr_options = BEST_ASR_CONFIG
elif asr_options in ["fast", "default"]:
self.asr_options = FAST_ASR_OPTIONS
else:
self.asr_options = FAST_ASR_OPTIONS

if isinstance(asr_options, dict):
self.asr_options.update(asr_options)
Expand All @@ -123,17 +125,17 @@ def __init__(
)

if self.asr_options["word_timestamps"]:
aligner_model = self.model_dir / self.asr_options["word_aligner_model"]
if not aligner_model.exists():
self.aligner_model_path = download_model(
self.asr_options["word_aligner_model"],
output_dir=aligner_model,
align_model = self.model_dir / self.asr_options["word_align_model"]
if not align_model.exists():
self.align_model_path = download_model(
self.asr_options["word_align_model"],
output_dir=align_model,
)
else:
self.aligner_model_path = aligner_model
self.align_model_path = align_model

self.aligner_model = ctranslate2.models.Whisper(
str(self.aligner_model_path),
self.align_model = ctranslate2.models.Whisper(
str(self.align_model_path),
device=device,
device_index=device_index,
compute_type=compute_type,
Expand Down Expand Up @@ -223,7 +225,7 @@ def align_words(

token_alignments = [[] for _ in seg_metadata]
for start_seq, req_idx in start_seq_wise_req.items():
res = self.aligner_model.align(
res = self.align_model.align(
ctranslate2.StorageView.from_array(features[req_idx]),
start_sequence=list(start_seq),
text_tokens=[text_tokens[_] for _ in req_idx],
Expand Down Expand Up @@ -264,7 +266,13 @@ def align_words(
return word_timings

def generate_segment_batched(
self, features, prompts, seq_lens, seg_metadata, generate_kwargs=None
self,
features,
prompts,
seg_metadata,
align_features,
align_seq_lens,
generate_kwargs=None,
):
if generate_kwargs is not None:
self.update_generation_kwargs(generate_kwargs)
Expand All @@ -281,7 +289,12 @@ def generate_segment_batched(
]
sot_seqs = [tuple(_[-4:]) for _ in prompts]
word_timings = self.align_words(
features, texts, text_tokens, sot_seqs, seq_lens, seg_metadata
align_features,
texts,
text_tokens,
sot_seqs,
align_seq_lens,
seg_metadata,
)

for _response, _word_timings in zip(response, word_timings):
Expand Down
17 changes: 12 additions & 5 deletions src/wordcab_transcribe/engines/tensorrt_llm/whisper_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,13 @@ def __init__(
tokenizer = NoneTokenizer()

self.tokenizer = tokenizer
self._init_dependables()
self._init_dependencies()

def _init_dependables(self):
def _init_dependencies(self):
self.dta_padding = int(self.dta_padding * 16000)
self.max_initial_prompt_len = self.max_text_token_len // 2 - 1
self.preprocessor = LogMelSpectogram(n_mels=self.n_mels).to(self.device)
self.align_preprocessor = LogMelSpectogram(n_mels=80).to(self.device)
self.speech_segmenter = SpeechSegmenter(
self.vad_model, device=self.device, **self.speech_segmenter_options
)
Expand All @@ -121,7 +122,7 @@ def update_params(self, params: dict):
for key, value in params.items():
setattr(self, key, value)

self._init_dependables()
self._init_dependencies()

@abstractmethod
def generate_segment_batched(self, features, prompts):
Expand Down Expand Up @@ -150,9 +151,15 @@ def transcribe(
batch_size=batch_size,
use_vad=use_vad,
):
mels, seq_lens = self.preprocessor(signals, seq_lens)
mels, _ = self.preprocessor(signals, seq_lens)
align_mels, align_seq_lens = self.align_preprocessor(signals, seq_lens)
res = self.generate_segment_batched(
mels.to(self.device), prompts, seq_lens, seg_metadata, generate_kwargs
mels.to(self.device),
prompts,
seg_metadata,
align_mels.to(self.device),
align_seq_lens,
generate_kwargs,
)
for res_idx, _seg_metadata in enumerate(seg_metadata):
responses[_seg_metadata["file_id"]].append(
Expand Down
4 changes: 4 additions & 0 deletions src/wordcab_transcribe/services/transcribe_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from loguru import logger
from tensorshare import Backend, TensorShare

from wordcab_transcribe.config import settings
from wordcab_transcribe.engines.tensorrt_llm.model import WhisperModelTRT
from wordcab_transcribe.models import (
MultiChannelSegment,
Expand Down Expand Up @@ -93,6 +94,9 @@ def __init__(
device=self.device,
device_index=device_index,
compute_type=self.compute_type,
asr_options={
"word_align_model": settings.align_model,
},
)
else:
self.model = WhisperModel(
Expand Down

0 comments on commit 6a0bdc4

Please sign in to comment.