diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py index d91bfcc..608f422 100644 --- a/src/wordcab_transcribe/config.py +++ b/src/wordcab_transcribe/config.py @@ -20,11 +20,9 @@ """Configuration module of the Wordcab Transcribe.""" from os import getenv -from pathlib import Path from typing import Dict, List from dotenv import load_dotenv -from faster_whisper.utils import _MODELS from loguru import logger from pydantic import field_validator from pydantic.dataclasses import dataclass @@ -112,21 +110,6 @@ def api_prefix_must_not_be_none(cls, value: str): # noqa: B902, N805 return value - @field_validator("whisper_model") - def whisper_model_must_be_valid(cls, value: str): # noqa: B902, N805 - """Check that the model name is valid. It can be a local path or a model name.""" - model_path = Path(value) - - if model_path.exists() is False: - if value not in _MODELS.keys(): - raise ValueError( - f"{value} is not a valid model name. Choose one of" - f" {_MODELS.keys()}.If you want to use a local model, please" - " provide a valid path." - ) - - return value - @field_validator("compute_type") def compute_type_must_be_valid(cls, value: str): # noqa: B902, N805 """Check that the model precision is valid.""" diff --git a/src/wordcab_transcribe/services/diarization/diarize_service.py b/src/wordcab_transcribe/services/diarization/diarize_service.py index b85d6ab..f54b353 100644 --- a/src/wordcab_transcribe/services/diarization/diarize_service.py +++ b/src/wordcab_transcribe/services/diarization/diarize_service.py @@ -80,9 +80,7 @@ def __init__( else: self.default_segmentation_batch_size = 256 - self.default_scale_dict = { - k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths)) - } + self.default_scale_dict = dict(enumerate(zip(window_lengths, shift_lengths))) for idx in device_index: _device = f"cuda:{idx}" if self.device == "cuda" else "cpu" @@ -127,22 +125,18 @@ def __call__( segmentation_batch_size = self.default_segmentation_batch_size multiscale_weights = self.default_multiscale_weights elif audio_duration < 10800: - scale_dict = { - k: (w, s) - for k, (w, s) in enumerate( + scale_dict = dict( + enumerate( zip( [3.0, 2.5, 2.0, 1.5, 1.0], self.default_shift_lengths, ) ) - } + ) segmentation_batch_size = 64 multiscale_weights = self.default_multiscale_weights else: - scale_dict = { - k: (w, s) - for k, (w, s) in enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25])) - } + scale_dict = dict(enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25]))) segmentation_batch_size = 32 multiscale_weights = [1.0, 1.0, 1.0] diff --git a/tests/test_config.py b/tests/test_config.py index 60a2cc9..9004106 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -126,19 +126,6 @@ def test_general_parameters_validator(default_settings: dict) -> None: Settings(**wrong_api_prefix) -def test_whisper_model_validator(default_settings: dict) -> None: - """Test whisper model validator.""" - wrong_whisper_model = default_settings.copy() - wrong_whisper_model["whisper_model"] = "invalid_model_name" - with pytest.raises(ValueError): - Settings(**wrong_whisper_model) - - wrong_whisper_model = default_settings.copy() - wrong_whisper_model["whisper_model"] = "/path/to/invalid_model" - with pytest.raises(ValueError): - Settings(**wrong_whisper_model) - - def test_compute_type_validator(default_settings: dict) -> None: """Test compute type validator.""" default_settings["compute_type"] = "invalid_compute_type"