diff --git a/.env b/.env
index 6c1e275..96871f3 100644
--- a/.env
+++ b/.env
@@ -8,7 +8,7 @@
 # The name of the project, used for API documentation.
 PROJECT_NAME="Wordcab Transcribe"
 # The version of the project, used for API documentation.
-VERSION="0.5.0"
+VERSION="0.5.1"
 # The description of the project, used for API documentation.
 DESCRIPTION="💬 ASR FastAPI server using faster-whisper and Auto-Tuning Spectral Clustering for diarization."
 # This API prefix is used for all endpoints in the API outside of the status and cortex endpoints.
@@ -60,6 +60,10 @@ MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0"
 # files are processed.
 # * `live` is the option to use when you want to process a live audio stream. It will process the audio in chunks,
 # and return the results as soon as they are available. Live option is still a feature in development.
+# * `only_transcription` is used to deploy a single transcription server.
+# This option is used when you want to deploy each service on a separate server.
+# * `only_diarization` is used to deploy a single diarization server.
+# This option is used when you want to deploy each service on a separate server.
 # Use `live` only if you need live results, otherwise, use `async`.
 ASR_TYPE="async"
 #
@@ -106,4 +110,17 @@ SVIX_API_KEY=
 # The svix_app_id parameter is used in the cortex implementation to enable webhooks.
 SVIX_APP_ID=
 #
+# -------------------------------------------------- REMOTE SERVERS -------------------------------------------------- #
+# The remote servers configuration is used to control the servers used to process requests when you don't want to
+# group all the services on one server.
+#
+# The TRANSCRIBE_SERVER_URLS parameter is used to control the URLs of the servers used to process transcription requests.
+# Each URL should be separated by a comma and have this format: "scheme://host:port".
+# e.g. TRANSCRIBE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
+TRANSCRIBE_SERVER_URLS=
+# The DIARIZE_SERVER_URLS parameter is used to control the URLs of the servers used to process diarization requests.
+# Each URL should be separated by a comma and have this format: "scheme://host:port".
+# e.g. DIARIZE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
+DIARIZE_SERVER_URLS=
+#
 # -------------------------------------------------------------------------------------------------------------------- #
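The two new deployment modes map onto the `Literal` that `config.py` introduces below, so an invalid `ASR_TYPE` now fails validation at startup instead of being carried around as a plain string. A minimal sketch of what that buys you (the `TypeAdapter` usage here is illustrative, not part of this patch):

```python
from typing import Literal

from pydantic import TypeAdapter, ValidationError

AsrType = Literal["async", "live", "only_transcription", "only_diarization"]
adapter = TypeAdapter(AsrType)

adapter.validate_python("only_diarization")  # accepted

try:
    adapter.validate_python("remote_transcribe")  # not a valid ASR_TYPE
except ValidationError as err:
    print(err.errors()[0]["type"])  # literal_error
```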
SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000" +DIARIZE_SERVER_URLS= +# # -------------------------------------------------------------------------------------------------------------------- # diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py index 608f422..a4e0e0e 100644 --- a/src/wordcab_transcribe/config.py +++ b/src/wordcab_transcribe/config.py @@ -20,12 +20,13 @@ """Configuration module of the Wordcab Transcribe.""" from os import getenv -from typing import Dict, List +from typing import Dict, List, Union from dotenv import load_dotenv from loguru import logger from pydantic import field_validator from pydantic.dataclasses import dataclass +from typing_extensions import Literal from wordcab_transcribe import __version__ @@ -44,14 +45,14 @@ class Settings: # Whisper whisper_model: str compute_type: str - extra_languages: List[str] - extra_languages_model_paths: Dict[str, str] + extra_languages: Union[List[str], None] + extra_languages_model_paths: Union[Dict[str, str], None] # Diarization window_lengths: List[float] shift_lengths: List[float] multiscale_weights: List[float] # ASR type configuration - asr_type: str + asr_type: Literal["async", "live", "only_transcription", "only_diarization"] # Endpoints configuration audio_file_endpoint: bool audio_url_endpoint: bool @@ -69,6 +70,9 @@ class Settings: # Svix configuration svix_api_key: str svix_app_id: str + # Remote servers configuration + transcribe_server_urls: Union[List[str], None] + diarize_server_urls: Union[List[str], None] @field_validator("project_name") def project_name_must_not_be_none(cls, value: str): # noqa: B902, N805 @@ -213,29 +217,46 @@ def __post_init__(self): # Extra languages _extra_languages = getenv("EXTRA_LANGUAGES", None) if _extra_languages is not None and _extra_languages != "": - extra_languages = _extra_languages.split(",") + extra_languages = [lang.strip() for lang in _extra_languages.split(",")] else: - extra_languages = [] + extra_languages = None + +extra_languages_model_paths = ( + {lang: "" for lang in extra_languages} if extra_languages is not None else None +) # Diarization scales _window_lengths = getenv("WINDOW_LENGTHS", None) if _window_lengths is not None: - window_lengths = [float(x) for x in _window_lengths.split(",")] + window_lengths = [float(x.strip()) for x in _window_lengths.split(",")] else: window_lengths = [1.5, 1.25, 1.0, 0.75, 0.5] _shift_lengths = getenv("SHIFT_LENGTHS", None) if _shift_lengths is not None: - shift_lengths = [float(x) for x in _shift_lengths.split(",")] + shift_lengths = [float(x.strip()) for x in _shift_lengths.split(",")] else: shift_lengths = [0.75, 0.625, 0.5, 0.375, 0.25] _multiscale_weights = getenv("MULTISCALE_WEIGHTS", None) if _multiscale_weights is not None: - multiscale_weights = [float(x) for x in _multiscale_weights.split(",")] + multiscale_weights = [float(x.strip()) for x in _multiscale_weights.split(",")] else: multiscale_weights = [1.0, 1.0, 1.0, 1.0, 1.0] +# Multi-servers configuration +_transcribe_server_urls = getenv("TRANSCRIBE_SERVER_URLS", None) +if _transcribe_server_urls is not None and _transcribe_server_urls != "": + transcribe_server_urls = [url.strip() for url in _transcribe_server_urls.split(",")] +else: + transcribe_server_urls = None + +_diarize_server_urls = getenv("DIARIZE_SERVER_URLS", None) +if _diarize_server_urls is not None and _diarize_server_urls != "": + diarize_server_urls = [url.strip() for url in _diarize_server_urls.split(",")] +else: + diarize_server_urls = None + settings = 
diff --git a/src/wordcab_transcribe/dependencies.py b/src/wordcab_transcribe/dependencies.py
index d637598..f99ef23 100644
--- a/src/wordcab_transcribe/dependencies.py
+++ b/src/wordcab_transcribe/dependencies.py
@@ -52,8 +52,14 @@
         multiscale_weights=settings.multiscale_weights,
         extra_languages=settings.extra_languages,
         extra_languages_model_paths=settings.extra_languages_model_paths,
+        transcribe_server_urls=settings.transcribe_server_urls,
+        diarize_server_urls=settings.diarize_server_urls,
         debug_mode=settings.debug,
     )
+elif settings.asr_type == "only_transcription":
+    asr = None
+elif settings.asr_type == "only_diarization":
+    asr = None
 else:
     raise ValueError(f"Invalid ASR type: {settings.asr_type}")

@@ -69,28 +75,29 @@ async def lifespan(app: FastAPI) -> None:
             " https://github.com/Wordcab/wordcab-transcribe/issues"
         )

-    if check_ffmpeg() is False:
-        logger.warning(
-            "FFmpeg is not installed on the host machine.\n"
-            "Please install it and try again: `sudo apt-get install ffmpeg`"
-        )
-        exit(1)
+    if settings.asr_type == "async" or settings.asr_type == "only_transcription":
+        if check_ffmpeg() is False:
+            logger.warning(
+                "FFmpeg is not installed on the host machine.\n"
+                "Please install it and try again: `sudo apt-get install ffmpeg`"
+            )
+            exit(1)

-    if settings.extra_languages:
-        logger.info("Downloading models for extra languages...")
-        for model in settings.extra_languages:
-            try:
-                model_path = download_model(
-                    compute_type=settings.compute_type, language=model
-                )
+        if settings.extra_languages is not None:
+            logger.info("Downloading models for extra languages...")
+            for model in settings.extra_languages:
+                try:
+                    model_path = download_model(
+                        compute_type=settings.compute_type, language=model
+                    )

-                if model_path is not None:
-                    settings.extra_languages_model_paths[model] = model_path
-                else:
-                    raise Exception(f"Coudn't download model for {model}")
+                    if model_path is not None:
+                        settings.extra_languages_model_paths[model] = model_path
+                    else:
+                        raise Exception(f"Couldn't download model for {model}")

-            except Exception as e:
-                logger.error(f"Error downloading model for {model}: {e}")
+                except Exception as e:
+                    logger.error(f"Error downloading model for {model}: {e}")

     logger.info("Warmup initialization...")
     await asr.inference_warmup()
diff --git a/src/wordcab_transcribe/main.py b/src/wordcab_transcribe/main.py
index b8eab6c..4acdbb5 100644
--- a/src/wordcab_transcribe/main.py
+++ b/src/wordcab_transcribe/main.py
@@ -55,7 +55,7 @@
 else:
     app.include_router(api_router, prefix=settings.api_prefix)

-if settings.cortex_endpoint:
+if settings.cortex_endpoint and settings.asr_type == "async":
     app.include_router(cortex_router, tags=["cortex"])
diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py
index 778b5fb..32a9b68 100644
--- a/src/wordcab_transcribe/models.py
+++ b/src/wordcab_transcribe/models.py
@@ -28,10 +28,10 @@
 class ProcessTimes(BaseModel):
     """The execution times of the different processes."""

-    total: float
-    transcription: float
-    diarization: Union[float, None]
-    post_processing: float
+    total: Union[float, None] = None
+    transcription: Union[float, None] = None
+    diarization: Union[float, None] = None
+    post_processing: Union[float, None] = None


 class Timestamps(str, Enum):
@@ -72,7 +72,7 @@ class BaseResponse(BaseModel):
     diarization: bool
     source_lang: str
     timestamps: str
-    vocab: List[str]
+    vocab: Union[List[str], None]
     word_timestamps: bool
     internal_vad: bool
     repetition_penalty: float
@@ -219,7 +219,7 @@ class CortexPayload(BaseModel):
     multi_channel: Optional[bool] = False
     source_lang: Optional[str] = "en"
     timestamps: Optional[Timestamps] = Timestamps.seconds
-    vocab: Optional[List[str]] = []
+    vocab: Union[List[str], None] = None
     word_timestamps: Optional[bool] = False
     internal_vad: Optional[bool] = False
     repetition_penalty: Optional[float] = 1.2
@@ -386,7 +386,7 @@ class BaseRequest(BaseModel):
     diarization: bool = False
     source_lang: str = "en"
     timestamps: Timestamps = Timestamps.seconds
-    vocab: List[str] = []
+    vocab: Union[List[str], None] = None
     word_timestamps: bool = False
     internal_vad: bool = False
     repetition_penalty: float = 1.2
@@ -397,10 +397,12 @@ class BaseRequest(BaseModel):

     @field_validator("vocab")
     def validate_each_vocab_value(
-        cls, value: List[str]  # noqa: B902, N805
+        cls, value: Union[List[str], None]  # noqa: B902, N805
     ) -> List[str]:
         """Validate the value of each vocab field."""
-        if not all(isinstance(v, str) for v in value):
+        if value == []:
+            return None
+        elif value is not None and not all(isinstance(v, str) for v in value):
             raise ValueError("`vocab` must be a list of strings.")

         return value
@@ -480,6 +482,14 @@ class Config:
         }


+class DiarizeResponse(BaseModel):
+    """Response model for the diarize endpoint."""
+
+
+class TranscribeResponse(BaseModel):
+    """Response model for the transcribe endpoint."""
+
+
 class Token(BaseModel):
     """Token model for authentication."""
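The `vocab` change is behavioral, not just cosmetic: the validator now normalizes an empty list to `None`, so downstream code only has to branch on `None`. A short sketch, assuming the remaining `BaseRequest` fields keep their defaults:

```python
from wordcab_transcribe.models import BaseRequest

assert BaseRequest(vocab=[]).vocab is None  # empty list is normalized to None
assert BaseRequest(vocab=["Wordcab"]).vocab == ["Wordcab"]
assert BaseRequest().vocab is None  # the default itself is now None
```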
+"""Custom Pydantic annotations for the Wordcab Transcribe API.""" + +from typing import Any + +import torch +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + + +class TorchTensorPydanticAnnotation: + """Pydantic annotation for torch.Tensor.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, + _source_type: Any, + _handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + """Custom validation and serialization for torch.Tensor.""" + + def validate_tensor(value) -> torch.Tensor: + if not isinstance(value, torch.Tensor): + raise ValueError(f"Expected a torch.Tensor but got {type(value)}") + return value + + return core_schema.chain_schema( + [ + core_schema.no_info_plain_validator_function(validate_tensor), + ] + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + # This is a custom representation for the tensor in JSON Schema. + # Here, it represents a tensor as an object with metadata. + return { + "type": "object", + "properties": { + "dtype": { + "type": "string", + "description": "Data type of the tensor", + }, + "shape": { + "type": "array", + "items": {"type": "integer"}, + "description": "Shape of the tensor", + }, + }, + "required": ["dtype", "shape"], + } diff --git a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py index c9b5bbc..f0b01bb 100644 --- a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py @@ -29,6 +29,7 @@ from wordcab_transcribe.dependencies import asr from wordcab_transcribe.models import AudioRequest, AudioResponse +from wordcab_transcribe.services.asr_service import ProcessException from wordcab_transcribe.utils import ( check_num_channels, delete_file, @@ -44,14 +45,14 @@ ) async def inference_with_audio( # noqa: C901 background_tasks: BackgroundTasks, - offset_start: float = Form(None), # noqa: B008 - offset_end: float = Form(None), # noqa: B008 + offset_start: Union[float, None] = Form(None), # noqa: B008 + offset_end: Union[float, None] = Form(None), # noqa: B008 num_speakers: int = Form(-1), # noqa: B008 diarization: bool = Form(False), # noqa: B008 multi_channel: bool = Form(False), # noqa: B008 source_lang: str = Form("en"), # noqa: B008 timestamps: str = Form("s"), # noqa: B008 - vocab: List[str] = Form([]), # noqa: B008 + vocab: Union[List[str], None] = Form(None), # noqa: B008 word_timestamps: bool = Form(False), # noqa: B008 internal_vad: bool = Form(False), # noqa: B008 repetition_penalty: float = Form(1.2), # noqa: B008 @@ -131,11 +132,11 @@ async def inference_with_audio( # noqa: C901 background_tasks.add_task(delete_file, filepath=filepath) - if isinstance(result, Exception): - logger.error(f"Error: {result}") + if isinstance(result, ProcessException): + logger.error(result.message) raise HTTPException( status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(result), + detail=str(result.message), ) else: utterances, process_times, audio_duration = result diff --git a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py index 9b63ac4..ecc350e 100644 --- a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py @@ -29,6 +29,7 @@ from wordcab_transcribe.dependencies 
diff --git a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py
index c9b5bbc..f0b01bb 100644
--- a/src/wordcab_transcribe/router/v1/audio_file_endpoint.py
+++ b/src/wordcab_transcribe/router/v1/audio_file_endpoint.py
@@ -29,6 +29,7 @@
 from wordcab_transcribe.dependencies import asr
 from wordcab_transcribe.models import AudioRequest, AudioResponse
+from wordcab_transcribe.services.asr_service import ProcessException
 from wordcab_transcribe.utils import (
     check_num_channels,
     delete_file,
@@ -44,14 +45,14 @@
 )
 async def inference_with_audio(  # noqa: C901
     background_tasks: BackgroundTasks,
-    offset_start: float = Form(None),  # noqa: B008
-    offset_end: float = Form(None),  # noqa: B008
+    offset_start: Union[float, None] = Form(None),  # noqa: B008
+    offset_end: Union[float, None] = Form(None),  # noqa: B008
     num_speakers: int = Form(-1),  # noqa: B008
     diarization: bool = Form(False),  # noqa: B008
     multi_channel: bool = Form(False),  # noqa: B008
     source_lang: str = Form("en"),  # noqa: B008
     timestamps: str = Form("s"),  # noqa: B008
-    vocab: List[str] = Form([]),  # noqa: B008
+    vocab: Union[List[str], None] = Form(None),  # noqa: B008
     word_timestamps: bool = Form(False),  # noqa: B008
     internal_vad: bool = Form(False),  # noqa: B008
     repetition_penalty: float = Form(1.2),  # noqa: B008
@@ -131,11 +132,11 @@

     background_tasks.add_task(delete_file, filepath=filepath)

-    if isinstance(result, Exception):
-        logger.error(f"Error: {result}")
+    if isinstance(result, ProcessException):
+        logger.error(result.message)
         raise HTTPException(
             status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(result),
+            detail=str(result.message),
         )
     else:
         utterances, process_times, audio_duration = result
diff --git a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py
index 9b63ac4..ecc350e 100644
--- a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py
+++ b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py
@@ -29,6 +29,7 @@
 from wordcab_transcribe.dependencies import asr, download_limit
 from wordcab_transcribe.models import AudioRequest, AudioResponse
+from wordcab_transcribe.services.asr_service import ProcessException
 from wordcab_transcribe.utils import (
     check_num_channels,
     delete_file,
@@ -94,11 +95,11 @@ async def inference_with_audio_url(

     background_tasks.add_task(delete_file, filepath=filepath)

-    if isinstance(result, Exception):
-        logger.error(f"Error: {result}")
+    if isinstance(result, ProcessException):
+        logger.error(result.message)
         raise HTTPException(
             status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=str(result),
+            detail=str(result.message),
         )
     else:
         utterances, process_times, audio_duration = result
+"""Diarize endpoint for the Remote Wordcab Transcribe API.""" + +from typing import Union + +from fastapi import APIRouter +from fastapi import status as http_status + +from wordcab_transcribe.models import DiarizeResponse + +router = APIRouter() + + +@router.post( + "", + response_model=Union[DiarizeResponse, str], + status_code=http_status.HTTP_200_OK, +) +async def remote_diarization() -> DiarizeResponse: + """Diarize endpoint for the Remote Wordcab Transcribe API.""" diff --git a/src/wordcab_transcribe/router/v1/endpoints.py b/src/wordcab_transcribe/router/v1/endpoints.py index 492cde2..6f5eee9 100644 --- a/src/wordcab_transcribe/router/v1/endpoints.py +++ b/src/wordcab_transcribe/router/v1/endpoints.py @@ -28,7 +28,9 @@ from wordcab_transcribe.router.v1.cortex_endpoint import ( # noqa: F401 router as cortex_router, ) +from wordcab_transcribe.router.v1.diarize_endpoint import router as diarize_router from wordcab_transcribe.router.v1.live_endpoint import router as live_router +from wordcab_transcribe.router.v1.transcribe_endpoint import router as transcribe_router from wordcab_transcribe.router.v1.youtube_endpoint import router as youtube_router api_router = APIRouter() @@ -38,12 +40,28 @@ ("audio_url_endpoint", audio_url_router, "/audio-url", "async"), ("youtube_endpoint", youtube_router, "/youtube", "async"), ) -live_routers = (("live_endpoint", live_router, "/live", "live"),) +live_routers = ("live_endpoint", live_router, "/live", "live") +transcribe_routers = ( + "transcribe_endpoint", + transcribe_router, + "/transcribe", + "transcription", +) +diarize_routers = ( + "diariaze_endpoint", + diarize_router, + "/diarize", + "diarization", +) if settings.asr_type == "async": routers = async_routers elif settings.asr_type == "live": routers = live_routers +elif settings.asr_type == "only_transcription": + routers = transcribe_routers +elif settings.asr_type == "only_diarization": + routers = diarize_routers else: raise ValueError(f"Invalid ASR type: {settings.asr_type}") diff --git a/src/wordcab_transcribe/router/v1/transcribe_endpoint.py b/src/wordcab_transcribe/router/v1/transcribe_endpoint.py new file mode 100644 index 0000000..e1002c8 --- /dev/null +++ b/src/wordcab_transcribe/router/v1/transcribe_endpoint.py @@ -0,0 +1,38 @@ +# Copyright 2023 The Wordcab Team. All rights reserved. +# +# Licensed under the Wordcab Transcribe License 0.1 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE +# +# Except as expressly provided otherwise herein, and to the fullest +# extent permitted by law, Licensor provides the Software (and each +# Contributor provides its Contributions) AS IS, and Licensor +# disclaims all warranties or guarantees of any kind, express or +# implied, whether arising under any law or from any usage in trade, +# or otherwise including but not limited to the implied warranties +# of merchantability, non-infringement, quiet enjoyment, fitness +# for a particular purpose, or otherwise. +# +# See the License for the specific language governing permissions +# and limitations under the License. 
+"""Transcribe endpoint for the Remote Wordcab Transcribe API.""" + +from typing import Union + +from fastapi import APIRouter +from fastapi import status as http_status + +from wordcab_transcribe.models import TranscribeResponse + +router = APIRouter() + + +@router.post( + "", + response_model=Union[TranscribeResponse, str], + status_code=http_status.HTTP_200_OK, +) +async def remote_transcription() -> TranscribeResponse: + """Transcribe endpoint for the Remote Wordcab Transcribe API.""" diff --git a/src/wordcab_transcribe/router/v1/youtube_endpoint.py b/src/wordcab_transcribe/router/v1/youtube_endpoint.py index db1cfb5..2a03612 100644 --- a/src/wordcab_transcribe/router/v1/youtube_endpoint.py +++ b/src/wordcab_transcribe/router/v1/youtube_endpoint.py @@ -29,6 +29,7 @@ from wordcab_transcribe.dependencies import asr, download_limit from wordcab_transcribe.models import BaseRequest, YouTubeResponse +from wordcab_transcribe.services.asr_service import ProcessException from wordcab_transcribe.utils import delete_file, download_audio_file router = APIRouter() @@ -72,11 +73,11 @@ async def inference_with_youtube( background_tasks.add_task(delete_file, filepath=filepath) - if isinstance(result, Exception): - logger.error(f"Error: {result}") + if isinstance(result, ProcessException): + logger.error(result.message) raise HTTPException( status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(result), + detail=str(result.message), ) else: utterances, process_times, audio_duration = result diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index dad3cf2..4fbd0db 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -25,20 +25,104 @@ import time import traceback from abc import ABC, abstractmethod +from enum import Enum from pathlib import Path -from typing import Dict, Iterable, List, Tuple, Union +from typing import Iterable, List, Tuple, Union import torch from loguru import logger +from pydantic import BaseModel +from typing_extensions import Annotated from wordcab_transcribe.logging import time_and_tell +from wordcab_transcribe.models import ProcessTimes, Timestamps +from wordcab_transcribe.pydantic_annotations import TorchTensorPydanticAnnotation +from wordcab_transcribe.services.concurrency_services import GPUService, URLService from wordcab_transcribe.services.diarization.diarize_service import DiarizeService -from wordcab_transcribe.services.gpu_service import GPUService from wordcab_transcribe.services.post_processing_service import PostProcessingService from wordcab_transcribe.services.transcribe_service import TranscribeService from wordcab_transcribe.services.vad_service import VadService from wordcab_transcribe.utils import early_return, format_segments, read_audio +PydanticTorchTensor = Annotated[torch.Tensor, TorchTensorPydanticAnnotation] + + +class ExceptionSource(str, Enum): + """Exception source enum.""" + + diarization = "diarization" + post_processing = "post_processing" + transcription = "transcription" + + +class ProcessException(BaseModel): + """Process exception model.""" + + source: ExceptionSource + message: str + + +class LocalExecution(BaseModel): + """Local execution model.""" + + index: Union[int, None] + + +class RemoteExecution(BaseModel): + """Remote execution model.""" + + url: str + + +class ASRTask(BaseModel): + """ASR Task model.""" + + audio: Union[List[PydanticTorchTensor], PydanticTorchTensor] + diarization: "DiarizationTask" + 
+    duration: float
+    multi_channel: bool
+    offset_start: Union[float, None]
+    post_processing: "PostProcessingTask"
+    process_times: ProcessTimes
+    timestamps_format: Timestamps
+    transcription: "TranscriptionTask"
+    word_timestamps: bool
+
+
+class DiarizationTask(BaseModel):
+    """Diarization Task model."""
+
+    execution: Union[LocalExecution, RemoteExecution, None]
+    num_speakers: int
+    result: Union[ProcessException, List[dict], None] = None
+
+
+class PostProcessingTask(BaseModel):
+    """Post Processing Task model."""
+
+    result: Union[ProcessException, List[dict], None] = None
+
+
+class TranscriptionOptions(BaseModel):
+    """Transcription options model."""
+
+    compression_ratio_threshold: float
+    condition_on_previous_text: bool
+    internal_vad: bool
+    log_prob_threshold: float
+    no_speech_threshold: float
+    repetition_penalty: float
+    source_lang: str
+    vocab: Union[List[str], None]
+
+
+class TranscriptionTask(BaseModel):
+    """Transcription Task model."""
+
+    execution: Union[LocalExecution, RemoteExecution]
+    options: TranscriptionOptions
+    result: Union[ProcessException, List[dict], None] = None
+

 class ASRService(ABC):
     """Base ASR Service module that handle all AI interactions and batch processing."""
@@ -55,17 +139,6 @@ def __init__(self) -> None:
         logger.info(f"NVIDIA GPUs available: {self.num_gpus}")
         self.num_cpus = os.cpu_count()

-        self.sample_rate = (
-            16000  # The sample rate to use for inference for all audio files (Hz)
-        )
-
-        self.queues = None  # the queue to store requests
-        self.queue_locks = None  # the locks to access the queues
-        self.needs_processing = (
-            None  # the flag to indicate if the queue needs processing
-        )
-        self.needs_processing_timer = None  # the timer to schedule processing
-
         if self.num_gpus > 1 and self.device == "cuda":
             self.device_index = list(range(self.num_gpus))
         else:
             self.device_index = [0]
@@ -93,20 +166,36 @@ def __init__(
         multiscale_weights: List[float],
         extra_languages: List[str],
         extra_languages_model_paths: List[str],
+        transcribe_server_urls: Union[List[str], None],
+        diarize_server_urls: Union[List[str], None],
         debug_mode: bool,
     ) -> None:
         """
         Initialize the ASRAsyncService class.

         Args:
-            whisper_model (str): The path to the whisper model.
-            compute_type (str): The compute type to use for inference.
-            window_lengths (List[int]): The window lengths to use for diarization.
-            shift_lengths (List[int]): The shift lengths to use for diarization.
-            multiscale_weights (List[float]): The multiscale weights to use for diarization.
-            extra_languages (List[str]): The list of extra languages to support.
-            extra_languages_model_paths (List[str]): The list of paths to the extra language models.
-            debug_mode (bool): Whether to run in debug mode.
+            whisper_model (str):
+                The path to the whisper model.
+            compute_type (str):
+                The compute type to use for inference.
+            window_lengths (List[int]):
+                The window lengths to use for diarization.
+            shift_lengths (List[int]):
+                The shift lengths to use for diarization.
+            multiscale_weights (List[float]):
+                The multiscale weights to use for diarization.
+            extra_languages (List[str]):
+                The list of extra languages to support.
+            extra_languages_model_paths (List[str]):
+                The list of paths to the extra language models.
+            transcribe_server_urls (Union[List[str], None]):
+                The list of URLs to the remote transcription servers.
+            diarize_server_urls (Union[List[str], None]):
+                The list of URLs to the remote diarization servers.
+            debug_mode (bool):
+                Whether to run in debug mode.
         """
         super().__init__()
@@ -138,6 +227,20 @@ def __init__(
             "temperature": 0.0,
         }

+        if transcribe_server_urls is not None:
+            self.use_remote_transcription = True
+            self.transcription_url_handler = URLService(
+                remote_urls=transcribe_server_urls
+            )
+        else:
+            self.use_remote_transcription = False
+
+        if diarize_server_urls is not None:
+            self.use_remote_diarization = True
+            self.diarization_url_handler = URLService(remote_urls=diarize_server_urls)
+        else:
+            self.use_remote_diarization = False
+
         self.debug_mode = debug_mode

     async def inference_warmup(self) -> None:
@@ -155,7 +258,7 @@ async def inference_warmup(self) -> None:
             multi_channel=False,
             source_lang="en",
             timestamps_format="s",
-            vocab=[],
+            vocab=None,
             word_timestamps=False,
             internal_vad=False,
             repetition_penalty=1.0,
@@ -175,7 +278,7 @@ async def process_input(  # noqa: C901
         multi_channel: bool,
         source_lang: str,
         timestamps_format: str,
-        vocab: List[str],
+        vocab: Union[List[str], None],
         word_timestamps: bool,
         internal_vad: bool,
         repetition_penalty: float,
@@ -183,7 +286,7 @@ async def process_input(  # noqa: C901
         log_prob_threshold: float,
         no_speech_threshold: float,
         condition_on_previous_text: bool,
-    ) -> Union[Tuple[List[dict], Dict[str, float], float], Exception]:
+    ) -> Union[Tuple[List[dict], ProcessTimes, float], Exception]:
         """Process the input request and return the results.

         This method will create a task and add it to the appropriate queues.
@@ -209,7 +312,7 @@ async def process_input(  # noqa: C901
                 Source language of the audio file.
             timestamps_format (str):
                 Timestamps format to use.
-            vocab (List[str]):
+            vocab (Union[List[str], None]):
                 List of words to use for the vocabulary.
             word_timestamps (bool):
                 Whether to return word timestamps or not.
@@ -230,11 +333,11 @@ async def process_input(  # noqa: C901
             to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

         Returns:
-            Union[Tuple[List[dict], Dict[str, float], float], Exception]:
+            Union[Tuple[List[dict], ProcessTimes, float], Exception]:
                 The results of the ASR pipeline or an exception if something went wrong.
                 Results are returned as a tuple of the following:
                 * List[dict]: The final results of the ASR pipeline.
-                * Dict[str, float]: the process times for each step.
+                * ProcessTimes: The process times of each step of the ASR pipeline.
                * float: The audio duration
        """
        if isinstance(filepath, list):
@@ -254,213 +357,205 @@ async def process_input(  # noqa: C901
                 filepath, offset_start=offset_start, offset_end=offset_end
             )

-        task = {
-            "input": audio,
-            "offset_start": offset_start,
-            "duration": duration,
-            "num_speakers": num_speakers,
-            "diarization": diarization,
-            "multi_channel": multi_channel,
-            "source_lang": source_lang,
-            "timestamps_format": timestamps_format,
-            "vocab": vocab,
-            "word_timestamps": word_timestamps,
-            "internal_vad": internal_vad,
-            "repetition_penalty": repetition_penalty,
-            "compression_ratio_threshold": compression_ratio_threshold,
-            "log_prob_threshold": log_prob_threshold,
-            "no_speech_threshold": no_speech_threshold,
-            "condition_on_previous_text": condition_on_previous_text,
-            "transcription_result": None,
-            "transcription_done": asyncio.Event(),
-            "diarization_result": None,
-            "diarization_done": asyncio.Event(),
-            "post_processing_result": None,
-            "post_processing_done": asyncio.Event(),
-            "process_times": {},
-        }
+        gpu_index = None
+        if self.use_remote_transcription:
+            _url = await self.transcription_url_handler.next_url()
+            transcription_execution = RemoteExecution(url=_url)
+        else:
+            gpu_index = await self.gpu_handler.get_device()
+            transcription_execution = LocalExecution(index=gpu_index)

-        gpu_index = -1  # Placeholder to track if GPU is acquired
+        if diarization and multi_channel is False:
+            if self.use_remote_diarization:
+                _url = await self.diarization_url_handler.next_url()
+                diarization_execution = RemoteExecution(url=_url)
+            else:
+                if gpu_index is None:
+                    gpu_index = await self.gpu_handler.get_device()

-        try:
-            # If GPU is required, acquire one
-            if self.device == "cuda":
-                gpu_index = await self.gpu_handler.get_device()
-                logger.info(f"Using GPU {gpu_index} for the task")
+                diarization_execution = LocalExecution(index=gpu_index)
+        else:
+            diarization_execution = None
+
+        task = ASRTask(
+            audio=audio,
+            diarization=DiarizationTask(
+                execution=diarization_execution, num_speakers=num_speakers
+            ),
+            duration=duration,
+            multi_channel=multi_channel,
+            offset_start=offset_start,
+            post_processing=PostProcessingTask(),
+            process_times=ProcessTimes(),
+            timestamps_format=timestamps_format,
+            transcription=TranscriptionTask(
+                execution=transcription_execution,
+                options=TranscriptionOptions(
+                    compression_ratio_threshold=compression_ratio_threshold,
+                    condition_on_previous_text=condition_on_previous_text,
+                    internal_vad=internal_vad,
+                    log_prob_threshold=log_prob_threshold,
+                    no_speech_threshold=no_speech_threshold,
+                    repetition_penalty=repetition_penalty,
+                    source_lang=source_lang,
+                    vocab=vocab,
+                ),
+            ),
+            word_timestamps=word_timestamps,
+        )

+        try:
             start_process_time = time.time()

-            asyncio.get_event_loop().run_in_executor(
+            transcription_task = asyncio.get_event_loop().run_in_executor(
                 None,
-                functools.partial(
-                    self.process_transcription, task, gpu_index, self.debug_mode
-                ),
+                functools.partial(self.process_transcription, task, self.debug_mode),
             )
+            diarization_task = asyncio.get_event_loop().run_in_executor(
+                None,
+                functools.partial(self.process_diarization, task, self.debug_mode),
+            )

-            if diarization and multi_channel is False:
-                asyncio.get_event_loop().run_in_executor(
-                    None,
-                    functools.partial(
-                        self.process_diarization, task, gpu_index, self.debug_mode
-                    ),
-                )
-            else:
-                task["process_times"]["diarization"] = None
-                task["diarization_done"].set()
-
-            await task["transcription_done"].wait()
-            await task["diarization_done"].wait()
+            await transcription_task
+            await diarization_task

-            if isinstance(task["diarization_result"], Exception):
-                self.gpu_handler.release_device(gpu_index)
-                gpu_index = -1
-                return task["diarization_result"]
+            if isinstance(task.diarization.result, ProcessException):
+                return task.diarization.result

             if (
                 diarization
-                and task["diarization_result"] is None
+                and task.diarization.result is None
                 and multi_channel is False
             ):
                 # Empty audio early return
                 return early_return(duration=duration)

-            if isinstance(task["transcription_result"], Exception):
-                self.gpu_handler.release_device(gpu_index)
-                gpu_index = -1
-                return task["transcription_result"]
-
-            self.gpu_handler.release_device(gpu_index)
-            gpu_index = -1
+            if isinstance(task.transcription.result, ProcessException):
+                return task.transcription.result

-            asyncio.get_event_loop().run_in_executor(
+            await asyncio.get_event_loop().run_in_executor(
                 None, functools.partial(self.process_post_processing, task)
             )
-            await task["post_processing_done"].wait()
+            if isinstance(task.post_processing.result, ProcessException):
+                return task.post_processing.result

-            if isinstance(task["post_processing_result"], Exception):
-                return task["post_processing_result"]
+            task.process_times.total = time.time() - start_process_time

-            result: List[dict] = task.pop("post_processing_result")
-            process_times: Dict[str, float] = task.pop("process_times")
-            process_times["total"]: float = time.time() - start_process_time
+            return task.post_processing.result, task.process_times, duration

-            del task  # Delete the task to free up memory
-
-            return result, process_times, duration
         except Exception as e:
             return e
+
         finally:
-            # Ensure GPU is released if it was acquired
-            if gpu_index != -1:
+            del task
+
+            if gpu_index is not None:
                 self.gpu_handler.release_device(gpu_index)

-    def process_transcription(
-        self, task: dict, gpu_index: int, debug_mode: bool
-    ) -> None:
+    def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
         """
         Process a task of transcription and update the task with the result.

         Args:
-            task (dict): The task and its parameters.
-            gpu_index (int): The GPU index to use for the transcription.
+            task (ASRTask): The task and its parameters.
             debug_mode (bool): Whether to run in debug mode or not.

         Returns:
             None: The task is updated with the result.
""" try: - result, process_time = time_and_tell( - lambda: self.services["transcription"]( - task["input"], - source_lang=task["source_lang"], - model_index=gpu_index, - suppress_blank=False, - vocab=None if task["vocab"] == [] else task["vocab"], - word_timestamps=True, - internal_vad=task["internal_vad"], - repetition_penalty=task["repetition_penalty"], - compression_ratio_threshold=task["compression_ratio_threshold"], - log_prob_threshold=task["log_prob_threshold"], - no_speech_threshold=task["no_speech_threshold"], - condition_on_previous_text=task["condition_on_previous_text"], - vad_service=self.services["vad"] if task["multi_channel"] else None, - ), - func_name="transcription", - debug_mode=debug_mode, - ) + if isinstance(task.transcription.execution, LocalExecution): + result, process_time = time_and_tell( + lambda: self.services["transcription"]( + task.audio, + model_index=task.transcription.execution.index, + suppress_blank=False, + word_timestamps=True, + **task.transcription.options.model_dump(), + ), + func_name="transcription", + debug_mode=debug_mode, + ) + elif isinstance(task.transcription.execution, RemoteExecution): + raise NotImplementedError("Remote execution is not implemented yet.") except Exception as e: - result = Exception( - f"Error in transcription gpu {gpu_index}: {e}\n{traceback.format_exc()}" + result = ProcessException( + source=ExceptionSource.transcription, + message=f"Error in transcription: {e}\n{traceback.format_exc()}", ) process_time = None finally: - task["process_times"]["transcription"] = process_time - task["transcription_result"] = result - task["transcription_done"].set() + task.process_times.transcription = process_time + task.transcription.result = result return None - def process_diarization(self, task: dict, gpu_index: int, debug_mode: bool) -> None: + def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: """ Process a task of diarization. Args: - task (dict): The task and its parameters. - gpu_index (int): The GPU index to use for the diarization. + task (ASRTask): The task and its parameters. debug_mode (bool): Whether to run in debug mode or not. Returns: None: The task is updated with the result. 
""" try: - result, process_time = time_and_tell( - lambda: self.services["diarization"]( - task["input"], - audio_duration=task["duration"], - oracle_num_speakers=task["num_speakers"], - model_index=gpu_index, - vad_service=self.services["vad"], - ), - func_name="diarization", - debug_mode=debug_mode, - ) + if isinstance(task.diarization.execution, LocalExecution): + result, process_time = time_and_tell( + lambda: self.services["diarization"]( + waveform=task.audio, + audio_duration=task.duration, + oracle_num_speakers=task.diarization.num_speakers, + model_index=task.diarization.execution.index, + vad_service=self.services["vad"], + ), + func_name="diarization", + debug_mode=debug_mode, + ) + elif isinstance(task.diarization.execution, RemoteExecution): + raise NotImplementedError("Remote execution is not implemented yet.") + elif task.diarization.execution is None: + result = None + process_time = None except Exception as e: - result = Exception(f"Error in diarization: {e}\n{traceback.format_exc()}") + result = ProcessException( + source=ExceptionSource.diarization, + message=f"Error in diarization: {e}\n{traceback.format_exc()}", + ) process_time = None finally: - task["process_times"]["diarization"] = process_time - task["diarization_result"] = result - task["diarization_done"].set() + task.process_times.diarization = process_time + task.diarization.result = result return None - def process_post_processing(self, task: dict) -> None: + def process_post_processing(self, task: ASRTask) -> None: """ Process a task of post-processing. Args: - task (dict): The task and its parameters. + task (ASRTask): The task and its parameters. Returns: None: The task is updated with the result. """ try: total_post_process_time = 0 - diarization = task["diarization"] - multi_channel = task["multi_channel"] - word_timestamps = task["word_timestamps"] + diarization = False if task.diarization.execution is None else True - if multi_channel: + if task.multi_channel: utterances, process_time = time_and_tell( lambda: self.services[ "post_processing" - ].multi_channel_speaker_mapping(task["transcription_result"]), - func_name="dual_channel_speaker_mapping", + ].multi_channel_speaker_mapping(task.transcription.result), + func_name="multi_channel_speaker_mapping", debug_mode=self.debug_mode, ) total_post_process_time += process_time @@ -468,7 +563,7 @@ def process_post_processing(self, task: dict) -> None: else: formatted_segments, process_time = time_and_tell( lambda: format_segments( - segments=task["transcription_result"], + segments=task.transcription.result, word_timestamps=True, ), func_name="format_segments", @@ -482,8 +577,8 @@ def process_post_processing(self, task: dict) -> None: "post_processing" ].single_channel_speaker_mapping( transcript_segments=formatted_segments, - speaker_timestamps=task["diarization_result"], - word_timestamps=word_timestamps, + speaker_timestamps=task.diarization.result, + word_timestamps=task.word_timestamps, ), func_name="single_channel_speaker_mapping", debug_mode=self.debug_mode, @@ -498,10 +593,10 @@ def process_post_processing(self, task: dict) -> None: ].final_processing_before_returning( utterances=utterances, diarization=diarization, - multi_channel=task["multi_channel"], - offset_start=task["offset_start"], - timestamps_format=task["timestamps_format"], - word_timestamps=word_timestamps, + multi_channel=task.multi_channel, + offset_start=task.offset_start, + timestamps_format=task.timestamps_format, + word_timestamps=task.word_timestamps, ), 
diff --git a/src/wordcab_transcribe/services/gpu_service.py b/src/wordcab_transcribe/services/concurrency_services.py
similarity index 66%
rename from src/wordcab_transcribe/services/gpu_service.py
rename to src/wordcab_transcribe/services/concurrency_services.py
index 1b0af3b..b604db0 100644
--- a/src/wordcab_transcribe/services/gpu_service.py
+++ b/src/wordcab_transcribe/services/concurrency_services.py
@@ -20,7 +20,7 @@
 """GPU service class to handle gpu availability for models."""

 import asyncio
-from typing import Any, Dict, List
+from typing import List


 class GPUService:
@@ -37,9 +37,6 @@ def __init__(self, device: str, device_index: List[int]) -> None:
         self.device: str = device
         self.device_index: List[int] = device_index

-        # Initialize the models dictionary that will hold the models for each GPU.
-        self.models: Dict[int, Any] = {}
-
         self.queue = asyncio.Queue(maxsize=len(self.device_index))
         for idx in self.device_index:
             self.queue.put_nowait(idx)
@@ -67,3 +64,42 @@ def release_device(self, device_index: int) -> None:
         """
         if not any(item == device_index for item in self.queue._queue):
             self.queue.put_nowait(device_index)
+
+
+class URLService:
+    """URL service class to handle multiple remote URLs."""
+
+    def __init__(self, remote_urls: List[str]) -> None:
+        """
+        Initialize the URL service.
+
+        Args:
+            remote_urls (List[str]): List of remote URLs to use.
+        """
+        self.remote_urls: List[str] = remote_urls
+
+        # If there is only one URL, we don't need to use a queue
+        if len(self.remote_urls) == 1:
+            self.queue = None
+        else:
+            self.queue = asyncio.Queue(maxsize=len(self.remote_urls))
+            for url in self.remote_urls:
+                self.queue.put_nowait(url)
+
+    async def next_url(self) -> str:
+        """
+        We use this to iterate equally over the available URLs.
+
+        Returns:
+            str: Next available URL.
+        """
+        if self.queue is None:
+            return self.remote_urls[0]
+
+        else:
+            url = self.queue.get_nowait()
+            # Unlike GPU we don't want to block remote ASR requests.
+            # So we re-insert the URL back into the queue after getting it.
+            self.queue.put_nowait(url)
+
+            return url
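`URLService.next_url()` is a non-blocking round-robin: each call takes the next URL off the queue and immediately re-queues it, unlike `GPUService`, which holds a device until it is explicitly released. For example:

```python
import asyncio

from wordcab_transcribe.services.concurrency_services import URLService


async def main() -> None:
    urls = URLService(remote_urls=["http://1.2.3.4:8000", "http://4.3.2.1:8000"])
    # Four calls cycle twice over the two servers, never blocking.
    print([await urls.next_url() for _ in range(4)])
    # ['http://1.2.3.4:8000', 'http://4.3.2.1:8000',
    #  'http://1.2.3.4:8000', 'http://4.3.2.1:8000']


asyncio.run(main())
```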
diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py
index 92db9ec..8838006 100644
--- a/src/wordcab_transcribe/services/post_processing_service.py
+++ b/src/wordcab_transcribe/services/post_processing_service.py
@@ -22,6 +22,7 @@
 import itertools
 from typing import List, Union

+from wordcab_transcribe.models import Timestamps
 from wordcab_transcribe.utils import convert_timestamp, format_punct, is_empty_string


@@ -320,19 +321,25 @@ def final_processing_before_returning(
         diarization: bool,
         multi_channel: bool,
         offset_start: Union[float, None],
-        timestamps_format: str,
+        timestamps_format: Timestamps,
         word_timestamps: bool,
     ) -> List[dict]:
         """
         Do final processing before returning the utterances to the API.

         Args:
-            utterances (List[dict]): List of utterances.
-            diarization (bool): Whether diarization is enabled.
-            multi_channel (bool): Whether multi-channel is enabled.
-            offset_start (Union[float, None]): Offset start.
-            timestamps_format (str): Timestamps format used for conversion.
-            word_timestamps (bool): Whether to include word timestamps.
+            utterances (List[dict]):
+                List of utterances.
+            diarization (bool):
+                Whether diarization is enabled.
+            multi_channel (bool):
+                Whether multi-channel is enabled.
+            offset_start (Union[float, None]):
+                Offset start.
+            timestamps_format (Timestamps):
+                Timestamps format. Can be `s`, `ms`, or `hms`.
+            word_timestamps (bool):
+                Whether to include word timestamps.

         Returns:
             List[dict]: List of utterances with final processing.
diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py
index 9e64469..59cac1e 100644
--- a/src/wordcab_transcribe/services/transcribe_service.py
+++ b/src/wordcab_transcribe/services/transcribe_service.py
@@ -25,8 +25,6 @@
 from faster_whisper import WhisperModel
 from loguru import logger

-from wordcab_transcribe.services.vad_service import VadService
-

 class FasterWhisperModel(NamedTuple):
     """Faster Whisper Model."""
@@ -93,7 +91,6 @@ def __call__(
         log_prob_threshold: float = -1.0,
         no_speech_threshold: float = 0.6,
         condition_on_previous_text: bool = True,
-        vad_service: Union[VadService, None] = None,
     ) -> Union[List[dict], List[List[dict]]]:
         """
         Run inference with the transcribe model.
@@ -127,8 +124,6 @@ def __call__(
             If True, the previous output of the model is provided as a prompt for the next window;
             disabling may make the text inconsistent across windows, but the model becomes less prone to
             getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-            vad_service (Union[VadService, None]):
-                VADService to use for voice activity detection in the multi_channel case. Defaults to None.

         Returns:
             Union[List[dict], List[List[dict]]]:
                 List of transcriptions. If the task is a multi_channel task,
diff --git a/src/wordcab_transcribe/utils.py b/src/wordcab_transcribe/utils.py
index 8b62ad1..9c95c44 100644
--- a/src/wordcab_transcribe/utils.py
+++ b/src/wordcab_transcribe/utils.py
@@ -38,6 +38,8 @@
 if TYPE_CHECKING:
     from fastapi import UploadFile

+from wordcab_transcribe.models import Timestamps
+

 # pragma: no cover
 async def async_run_subprocess(command: List[str]) -> tuple:
@@ -120,14 +122,14 @@ async def check_num_channels(filepath: Union[str, Path]) -> int:


 def convert_timestamp(
-    timestamp: float, target: str, round_digits: Optional[int] = 3
+    timestamp: float, target: Timestamps, round_digits: Optional[int] = 3
 ) -> Union[str, float]:
     """
     Use the right function to convert the timestamp.

     Args:
         timestamp (float): Timestamp to convert.
-        target (str): Timestamp to convert.
+        target (Timestamps): Target timestamp format.
         round_digits (int, optional): Number of digits to round the timestamp. Defaults to 3.

     Returns:
@@ -136,11 +138,11 @@ def convert_timestamp(
     Raises:
         ValueError: If the target is invalid. Valid targets are: ms, hms, s.
     """
-    if target == "ms":
+    if target == Timestamps.milliseconds:
         return round(_convert_s_to_ms(timestamp), round_digits)
-    elif target == "hms":
+    elif target == Timestamps.hour_minute_second:
         return _convert_s_to_hms(timestamp)
-    elif target == "s":
+    elif target == Timestamps.seconds:
         return round(timestamp, round_digits)
     else:
         raise ValueError(
diff --git a/tests/test_config.py b/tests/test_config.py
index 9004106..789d7ad 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -77,8 +77,8 @@ def test_config() -> None:
     assert settings.whisper_model == "large-v2"
     assert settings.compute_type == "float16"

-    assert settings.extra_languages == []
-    assert settings.extra_languages_model_paths == {}
+    assert settings.extra_languages is None
+    assert settings.extra_languages_model_paths is None

     assert settings.window_lengths == [1.5, 1.25, 1.0, 0.75, 0.5]
     assert settings.shift_lengths == [0.75, 0.625, 0.5, 0.375, 0.25]
diff --git a/tests/test_models.py b/tests/test_models.py
index c4bc1bb..e1447b0 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -157,7 +157,7 @@ def test_audio_request() -> None:
     assert request.multi_channel is True
     assert request.source_lang == "en"
     assert request.timestamps == "s"
-    assert request.vocab == []
+    assert request.vocab is None
     assert request.word_timestamps is False
     assert request.internal_vad is False
     assert request.repetition_penalty == 1.2
@@ -450,7 +450,7 @@ def test_cortex_payload() -> None:
     assert payload.multi_channel is False
     assert payload.source_lang == "en"
    assert payload.timestamps == "s"
-    assert payload.vocab == []
+    assert payload.vocab is None
     assert payload.word_timestamps is False
     assert payload.internal_vad is False
     assert payload.repetition_penalty == 1.2
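Since `Timestamps` subclasses `str`, the move from string comparisons to enum comparisons in `convert_timestamp` keeps accepting the same wire values, as the unchanged `timestamps == "s"` assertions in the tests confirm. A short check based on the branches above:

```python
from wordcab_transcribe.models import Timestamps
from wordcab_transcribe.utils import convert_timestamp

assert Timestamps.seconds == "s"  # str-enum: old string values still compare equal
assert convert_timestamp(1.23456, Timestamps.seconds) == 1.235  # round(..., 3)

# Milliseconds go through _convert_s_to_ms before rounding, and hms through
# _convert_s_to_hms, per the branches above.
print(convert_timestamp(1.23456, Timestamps.milliseconds))
```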