Skip to content

Commit

Permalink
Let faster-whisper handle model path (#251)
Browse files Browse the repository at this point in the history
* let faster-whisper handle model path

* fix quality ruff warnings

* remove useless tests
  • Loading branch information
Thomas Chaigneau authored Sep 23, 2023
1 parent b117d53 commit dfce3fc
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 41 deletions.
17 changes: 0 additions & 17 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,9 @@
"""Configuration module of the Wordcab Transcribe."""

from os import getenv
from pathlib import Path
from typing import Dict, List

from dotenv import load_dotenv
from faster_whisper.utils import _MODELS
from loguru import logger
from pydantic import field_validator
from pydantic.dataclasses import dataclass
Expand Down Expand Up @@ -112,21 +110,6 @@ def api_prefix_must_not_be_none(cls, value: str): # noqa: B902, N805

return value

@field_validator("whisper_model")
def whisper_model_must_be_valid(cls, value: str): # noqa: B902, N805
"""Check that the model name is valid. It can be a local path or a model name."""
model_path = Path(value)

if model_path.exists() is False:
if value not in _MODELS.keys():
raise ValueError(
f"{value} is not a valid model name. Choose one of"
f" {_MODELS.keys()}.If you want to use a local model, please"
" provide a valid path."
)

return value

@field_validator("compute_type")
def compute_type_must_be_valid(cls, value: str): # noqa: B902, N805
"""Check that the model precision is valid."""
Expand Down
16 changes: 5 additions & 11 deletions src/wordcab_transcribe/services/diarization/diarize_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,7 @@ def __init__(
else:
self.default_segmentation_batch_size = 256

self.default_scale_dict = {
k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths))
}
self.default_scale_dict = dict(enumerate(zip(window_lengths, shift_lengths)))

for idx in device_index:
_device = f"cuda:{idx}" if self.device == "cuda" else "cpu"
Expand Down Expand Up @@ -127,22 +125,18 @@ def __call__(
segmentation_batch_size = self.default_segmentation_batch_size
multiscale_weights = self.default_multiscale_weights
elif audio_duration < 10800:
scale_dict = {
k: (w, s)
for k, (w, s) in enumerate(
scale_dict = dict(
enumerate(
zip(
[3.0, 2.5, 2.0, 1.5, 1.0],
self.default_shift_lengths,
)
)
}
)
segmentation_batch_size = 64
multiscale_weights = self.default_multiscale_weights
else:
scale_dict = {
k: (w, s)
for k, (w, s) in enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25]))
}
scale_dict = dict(enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25])))
segmentation_batch_size = 32
multiscale_weights = [1.0, 1.0, 1.0]

Expand Down
13 changes: 0 additions & 13 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,19 +126,6 @@ def test_general_parameters_validator(default_settings: dict) -> None:
Settings(**wrong_api_prefix)


def test_whisper_model_validator(default_settings: dict) -> None:
"""Test whisper model validator."""
wrong_whisper_model = default_settings.copy()
wrong_whisper_model["whisper_model"] = "invalid_model_name"
with pytest.raises(ValueError):
Settings(**wrong_whisper_model)

wrong_whisper_model = default_settings.copy()
wrong_whisper_model["whisper_model"] = "/path/to/invalid_model"
with pytest.raises(ValueError):
Settings(**wrong_whisper_model)


def test_compute_type_validator(default_settings: dict) -> None:
"""Test compute type validator."""
default_settings["compute_type"] = "invalid_compute_type"
Expand Down

0 comments on commit dfce3fc

Please sign in to comment.