diff --git a/.env b/.env
index 96871f3..e069d39 100644
--- a/.env
+++ b/.env
@@ -69,17 +69,9 @@ ASR_TYPE="async"
 #
 # --------------------------------------------- ENDPOINTS CONFIGURATION ----------------------------------------------
 #
 #
-# Include the `audio` endpoint in the API. This endpoint is used to process uploaded local audio files.
-AUDIO_FILE_ENDPOINT=True
-# Include the `audio-url` endpoint in the API. This endpoint is used to process audio files from a URL.
-AUDIO_URL_ENDPOINT=True
 # Include the cortex endpoint in the API. This endpoint is used to process audio files from the Cortex API.
 # Use this only if you deploy the API using Cortex and Kubernetes.
 CORTEX_ENDPOINT=True
-# Include the `youtube` endpoint in the API. This endpoint is used to process audio files from YouTube URLs.
-YOUTUBE_ENDPOINT=True
-# Include the `live` endpoint in the API. This endpoint is used to process live audio streams.
-LIVE_ENDPOINT=False
 #
 # ---------------------------------------- API AUTHENTICATION CONFIGURATION ------------------------------------------
 #
 # The API authentication is used to control the access to the API endpoints.
diff --git a/notebooks/audio_url_inference.py b/notebooks/audio_url_inference.py
new file mode 100644
index 0000000..558e3eb
--- /dev/null
+++ b/notebooks/audio_url_inference.py
@@ -0,0 +1,30 @@
+import json
+import requests
+
+headers = {"accept": "application/json", "Content-Type": "application/json"}
+params = {"url": "https://github.com/Wordcab/wordcab-python/raw/main/tests/sample_1.mp3"}
+
+
+data = {
+    "offset_start": None,
+    "offset_end": None,
+    "num_speakers": -1,  # Leave at -1 to guess the number of speakers
+    "diarization": True,  # Longer processing time but speaker segment attribution
+    "source_lang": "en",  # optional, default is "en"
+    "timestamps": "s",  # optional, default is "s". Can be "s", "ms" or "hms".
+    "internal_vad": False,  # optional, default is False
+    "vocab": ["Martha's Flowers", "Thomas", "Randal"],  # optional, default is None
+    "word_timestamps": False,  # optional, default is False
+}
+
+response = requests.post(
+    "http://localhost:5001/api/v1/audio-url",
+    headers=headers,
+    params=params,
+    data=json.dumps(data),
+)
+
+r_json = response.json()
+
+with open("data/audio_url_output.json", "w", encoding="utf-8") as f:
+    json.dump(r_json, f, indent=4, ensure_ascii=False)
diff --git a/notebooks/transcribe_endpoint_only.py b/notebooks/transcribe_endpoint_only.py
new file mode 100644
index 0000000..efeddda
--- /dev/null
+++ b/notebooks/transcribe_endpoint_only.py
@@ -0,0 +1,124 @@
+import asyncio
+import io
+import json
+from typing import List, Tuple, Union
+
+import aiohttp
+import soundfile as sf
+import torch
+import torchaudio
+from pydantic import BaseModel
+from tensorshare import Backend, TensorShare
+
+
+def read_audio(
+    audio: Union[str, bytes],
+    offset_start: Union[float, None] = None,
+    offset_end: Union[float, None] = None,
+    sample_rate: int = 16000,
+) -> Tuple[torch.Tensor, float]:
+    """
+    Read an audio file and return the audio tensor.
+
+    Args:
+        audio (Union[str, bytes]):
+            Path to the audio file or the audio bytes.
+        offset_start (Union[float, None], optional):
+            When to start reading the audio file. Defaults to None.
+        offset_end (Union[float, None], optional):
+            When to stop reading the audio file. Defaults to None.
+        sample_rate (int):
+            The sample rate of the audio file. Defaults to 16000.
+
+    Returns:
+        Tuple[torch.Tensor, float]: The audio tensor and the audio duration.
+    """
+    if isinstance(audio, str):
+        wav, sr = torchaudio.load(audio)
+    elif isinstance(audio, bytes):
+        with io.BytesIO(audio) as buffer:
+            wav, sr = sf.read(
+                buffer, format="RAW", channels=1, samplerate=16000, subtype="PCM_16"
+            )
+        wav = torch.from_numpy(wav).unsqueeze(0)
+    else:
+        raise ValueError(
+            f"Invalid audio type. Must be either str or bytes, got: {type(audio)}."
+        )
+
+    if wav.size(0) > 1:
+        wav = wav.mean(dim=0, keepdim=True)
+
+    if sr != sample_rate:
+        transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
+        wav = transform(wav)
+        sr = sample_rate
+
+    audio_duration = float(wav.shape[1]) / sample_rate
+
+    # Convert start and end times to sample indices based on the new sample rate
+    if offset_start is not None:
+        start_sample = int(offset_start * sr)
+    else:
+        start_sample = 0
+
+    if offset_end is not None:
+        end_sample = int(offset_end * sr)
+    else:
+        end_sample = wav.shape[1]
+
+    # Trim the audio based on the new start and end samples
+    wav = wav[:, start_sample:end_sample]
+
+    return wav.squeeze(0), audio_duration
+
+
+class TranscribeRequest(BaseModel):
+    """Request model for the transcribe endpoint."""
+
+    audio: Union[TensorShare, List[TensorShare]]
+    compression_ratio_threshold: float
+    condition_on_previous_text: bool
+    internal_vad: bool
+    log_prob_threshold: float
+    no_speech_threshold: float
+    repetition_penalty: float
+    source_lang: str
+    vocab: Union[List[str], None]
+
+
+async def main():
+
+    audio, _ = read_audio("data/HL_Podcast_1.mp3")
+    ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)
+
+    data = TranscribeRequest(
+        audio=ts,
+        source_lang="en",
+        compression_ratio_threshold=2.4,
+        condition_on_previous_text=True,
+        internal_vad=False,
+        log_prob_threshold=-1.0,
+        no_speech_threshold=0.6,
+        repetition_penalty=1.0,
+        vocab=None,
+    )
+
+    async with aiohttp.ClientSession() as session:
+        async with session.post(
+            url="http://0.0.0.0:5002/api/v1/transcribe",
+            data=data.model_dump_json(),
+            headers={"Content-Type": "application/json"},
+        ) as response:
+            if response.status != 200:
+                raise Exception(
+                    f"Remote transcription failed with status {response.status}."
+ ) + else: + r = await response.json() + + with open("remote_test.json", "w") as f: + f.write(json.dumps(r, indent=4)) + + +asyncio.run(main()) diff --git a/pyproject.toml b/pyproject.toml index d2002e9..75e9fca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,10 +29,19 @@ classifiers = [ dependencies = [ "aiohttp>=3.8.4", "aiofiles>=23.1.0", + "ctranslate2>=3.18.0", + "faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master", "ffmpeg-python>=0.2.0", + "librosa>=0.9.0", "loguru>=0.6.0", + "numpy==1.23.1", + "onnxruntime>=1.15.0", "pydantic>=1.10.9", "python-dotenv>=1.0.0", + "tensorshare>=0.1.1", + "torch>=2.0.0", + "torchaudio>=2.0.1", + "wget>=3.2.0", "yt-dlp>=2023.3.4", ] @@ -48,16 +57,6 @@ path = "src/wordcab_transcribe/__init__.py" allow-direct-references = true [project.optional-dependencies] -inference = [ - "ctranslate2>=3.18.0", - "faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master", - "librosa>=0.9.0", - "numpy==1.23.1", - "onnxruntime>=1.15.0", - "torch>=2.0.0", - "torchaudio>=2.0.1", - "wget>=3.2.0", -] runtime = [ "argon2-cffi>=21.3.0", "fastapi>=0.96.0", @@ -67,7 +66,6 @@ runtime = [ "svix>=0.85.1", "uvicorn>=0.21.1", "websockets>=11.0.3", - "wordcab_transcribe[inference]", ] quality = [ "black>=22.10.0", @@ -79,9 +77,8 @@ tests = [ "pytest>=7.4", "pytest-asyncio>=0.21.1", "pytest-cov>=4.1", - "wordcab_transcribe[inference]", ] -dev = ["wordcab_transcribe[quality,inference,runtime,tests]"] +dev = ["wordcab_transcribe[quality,runtime,tests]"] [tool.hatch.envs.runtime] features = [ diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py index a4e0e0e..5d992ae 100644 --- a/src/wordcab_transcribe/config.py +++ b/src/wordcab_transcribe/config.py @@ -53,12 +53,8 @@ class Settings: multiscale_weights: List[float] # ASR type configuration asr_type: Literal["async", "live", "only_transcription", "only_diarization"] - # Endpoints configuration - audio_file_endpoint: bool - audio_url_endpoint: bool + # Endpoint configuration cortex_endpoint: bool - youtube_endpoint: bool - live_endpoint: bool # API authentication configuration username: str password: str @@ -134,16 +130,6 @@ def compute_type_must_be_valid(cls, value: str): # noqa: B902, N805 return value - @field_validator("asr_type") - def asr_type_must_be_valid(cls, value: str): # noqa: B902, N805 - """Check that the ASR type is valid.""" - if value not in {"async", "live"}: - raise ValueError( - f"{value} is not a valid ASR type. Choose between `async` or `live`." 
- ) - - return value - @field_validator("openssl_algorithm") def openssl_algorithm_must_be_valid(cls, value: str): # noqa: B902, N805 """Check that the OpenSSL algorithm is valid.""" @@ -168,16 +154,6 @@ def access_token_expire_minutes_must_be_valid(cls, value: int): # noqa: B902, N def __post_init__(self): """Post initialization checks.""" - endpoints = [ - self.audio_file_endpoint, - self.audio_url_endpoint, - self.cortex_endpoint, - self.youtube_endpoint, - self.live_endpoint, - ] - if not any(endpoints): - raise ValueError("At least one endpoint configuration must be set to True.") - if self.debug is False: if self.username == "admin" or self.username is None: # noqa: S105 logger.warning( @@ -281,11 +257,7 @@ def __post_init__(self): # ASR type asr_type=getenv("ASR_TYPE", "async"), # Endpoints configuration - audio_file_endpoint=getenv("AUDIO_FILE_ENDPOINT", True), - audio_url_endpoint=getenv("AUDIO_URL_ENDPOINT", True), cortex_endpoint=getenv("CORTEX_ENDPOINT", True), - youtube_endpoint=getenv("YOUTUBE_ENDPOINT", True), - live_endpoint=getenv("LIVE_ENDPOINT", False), # API authentication configuration username=getenv("USERNAME", "admin"), password=getenv("PASSWORD", "admin"), diff --git a/src/wordcab_transcribe/dependencies.py b/src/wordcab_transcribe/dependencies.py index f99ef23..5580930 100644 --- a/src/wordcab_transcribe/dependencies.py +++ b/src/wordcab_transcribe/dependencies.py @@ -26,7 +26,12 @@ from loguru import logger from wordcab_transcribe.config import settings -from wordcab_transcribe.services.asr_service import ASRAsyncService, ASRLiveService +from wordcab_transcribe.services.asr_service import ( + ASRAsyncService, + ASRDiarizationOnly, + ASRLiveService, + ASRTranscriptionOnly, +) from wordcab_transcribe.utils import ( check_ffmpeg, download_model, @@ -57,9 +62,20 @@ debug_mode=settings.debug, ) elif settings.asr_type == "only_transcription": - asr = None + asr = ASRTranscriptionOnly( + whisper_model=settings.whisper_model, + compute_type=settings.compute_type, + extra_languages=settings.extra_languages, + extra_languages_model_paths=settings.extra_languages_model_paths, + debug_mode=settings.debug, + ) elif settings.asr_type == "only_diarization": - asr = None + asr = ASRDiarizationOnly( + window_lengths=settings.window_lengths, + shift_lengths=settings.shift_lengths, + multiscale_weights=settings.multiscale_weights, + debug_mode=settings.debug, + ) else: raise ValueError(f"Invalid ASR type: {settings.asr_type}") @@ -75,7 +91,7 @@ async def lifespan(app: FastAPI) -> None: " https://github.com/Wordcab/wordcab-transcribe/issues" ) - if settings.asr_type == "async" or settings.asr_type == "remote_transcribe": + if settings.asr_type == "async" or settings.asr_type == "only_transcription": if check_ffmpeg() is False: logger.warning( "FFmpeg is not installed on the host machine.\n" diff --git a/src/wordcab_transcribe/logging.py b/src/wordcab_transcribe/logging.py index 96e46f2..e1d0fcb 100644 --- a/src/wordcab_transcribe/logging.py +++ b/src/wordcab_transcribe/logging.py @@ -20,9 +20,11 @@ """Logging module to add a logging middleware to the Wordcab Transcribe API.""" +import asyncio import sys import time import uuid +from functools import partial from typing import Any, Awaitable, Callable, Tuple from loguru import logger @@ -93,7 +95,42 @@ def time_and_tell( The appropriate wrapper for the function. 
""" start_time = time.time() - result = func() + result = func + process_time = time.time() - start_time + + if debug_mode: + logger.debug(f"{func_name} executed in {process_time:.4f} secs") + + return result, process_time + + +async def time_and_tell_async( + func: Callable, func_name: str, debug_mode: bool +) -> Tuple[Any, float]: + """ + This decorator logs the execution time of an async function only if the debug setting is True. + + Args: + func: The function to call in the wrapper. + func_name: The name of the function for logging purposes. + debug_mode: The debug setting for logging purposes. + + Returns: + The appropriate wrapper for the function. + """ + start_time = time.time() + + if asyncio.iscoroutinefunction(func) or asyncio.iscoroutine(func): + result = await func + else: + loop = asyncio.get_event_loop() + if isinstance(func, partial): + result = await loop.run_in_executor( + None, func.func, *func.args, **func.keywords + ) + else: + result = await loop.run_in_executor(None, func) + process_time = time.time() - start_time if debug_mode: diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py index 32a9b68..cf1df0d 100644 --- a/src/wordcab_transcribe/models.py +++ b/src/wordcab_transcribe/models.py @@ -20,9 +20,11 @@ """Models module of the Wordcab Transcribe.""" from enum import Enum -from typing import List, Literal, Optional, Union +from typing import List, Literal, NamedTuple, Optional, Union +from faster_whisper.transcribe import Segment from pydantic import BaseModel, field_validator +from tensorshare import TensorShare class ProcessTimes(BaseModel): @@ -48,7 +50,7 @@ class Word(BaseModel): word: str start: float end: float - score: float + probability: float class Utterance(BaseModel): @@ -57,8 +59,8 @@ class Utterance(BaseModel): text: str start: Union[float, str] end: Union[float, str] - speaker: Optional[int] - words: Optional[List[Word]] + speaker: Union[int, None] = None + words: Union[List[Word], None] = None class BaseResponse(BaseModel): @@ -482,12 +484,46 @@ class Config: } -class DiarizeResponse(BaseModel): - """Response model for the diarize endpoint.""" +class DiarizationSegment(NamedTuple): + """Diarization segment model for the API.""" + start: float + end: float + speaker: int + + +class DiarizationOutput(BaseModel): + """Diarization output model for the API.""" + + segments: List[DiarizationSegment] + + +class DiarizationRequest(BaseModel): + """Request model for the diarize endpoint.""" + + audio: TensorShare + duration: float + num_speakers: int + + +class TranscriptionOutput(BaseModel): + """Transcription output model for the API.""" -class TranscribeResponse(BaseModel): - """Response model for the transcribe endpoint.""" + segments: List[Segment] + + +class TranscribeRequest(BaseModel): + """Request model for the transcribe endpoint.""" + + audio: Union[TensorShare, List[TensorShare]] + compression_ratio_threshold: float + condition_on_previous_text: bool + internal_vad: bool + log_prob_threshold: float + no_speech_threshold: float + repetition_penalty: float + source_lang: str + vocab: Union[List[str], None] class Token(BaseModel): diff --git a/src/wordcab_transcribe/pydantic_annotations.py b/src/wordcab_transcribe/pydantic_annotations.py deleted file mode 100644 index c881a08..0000000 --- a/src/wordcab_transcribe/pydantic_annotations.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2023 The Wordcab Team. All rights reserved. 
-# -# Licensed under the Wordcab Transcribe License 0.1 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE -# -# Except as expressly provided otherwise herein, and to the fullest -# extent permitted by law, Licensor provides the Software (and each -# Contributor provides its Contributions) AS IS, and Licensor -# disclaims all warranties or guarantees of any kind, express or -# implied, whether arising under any law or from any usage in trade, -# or otherwise including but not limited to the implied warranties -# of merchantability, non-infringement, quiet enjoyment, fitness -# for a particular purpose, or otherwise. -# -# See the License for the specific language governing permissions -# and limitations under the License. -"""Custom Pydantic annotations for the Wordcab Transcribe API.""" - -from typing import Any - -import torch -from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler -from pydantic.json_schema import JsonSchemaValue -from pydantic_core import core_schema - - -class TorchTensorPydanticAnnotation: - """Pydantic annotation for torch.Tensor.""" - - @classmethod - def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - """Custom validation and serialization for torch.Tensor.""" - - def validate_tensor(value) -> torch.Tensor: - if not isinstance(value, torch.Tensor): - raise ValueError(f"Expected a torch.Tensor but got {type(value)}") - return value - - return core_schema.chain_schema( - [ - core_schema.no_info_plain_validator_function(validate_tensor), - ] - ) - - @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - # This is a custom representation for the tensor in JSON Schema. - # Here, it represents a tensor as an object with metadata. 
- return { - "type": "object", - "properties": { - "dtype": { - "type": "string", - "description": "Data type of the tensor", - }, - "shape": { - "type": "array", - "items": {"type": "integer"}, - "description": "Shape of the tensor", - }, - }, - "required": ["dtype", "shape"], - } diff --git a/src/wordcab_transcribe/router/v1/diarize_endpoint.py b/src/wordcab_transcribe/router/v1/diarize_endpoint.py index d42d646..dfab61c 100644 --- a/src/wordcab_transcribe/router/v1/diarize_endpoint.py +++ b/src/wordcab_transcribe/router/v1/diarize_endpoint.py @@ -21,18 +21,33 @@ from typing import Union -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from fastapi import status as http_status +from loguru import logger -from wordcab_transcribe.models import DiarizeResponse +from wordcab_transcribe.dependencies import asr +from wordcab_transcribe.models import DiarizationOutput, DiarizationRequest +from wordcab_transcribe.services.asr_service import ProcessException router = APIRouter() @router.post( "", - response_model=Union[DiarizeResponse, str], + response_model=Union[DiarizationOutput, str], status_code=http_status.HTTP_200_OK, ) -async def remote_diarization() -> DiarizeResponse: - """Diarize endpoint for the Remote Wordcab Transcribe API.""" +async def remote_diarization( + data: DiarizationRequest, +) -> DiarizationOutput: + """Diarize endpoint for the `only_diarization` asr type.""" + result: DiarizationOutput = await asr.process_input(data) + + if isinstance(result, ProcessException): + logger.error(result.message) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result diff --git a/src/wordcab_transcribe/router/v1/endpoints.py b/src/wordcab_transcribe/router/v1/endpoints.py index 6f5eee9..6ff8e6d 100644 --- a/src/wordcab_transcribe/router/v1/endpoints.py +++ b/src/wordcab_transcribe/router/v1/endpoints.py @@ -35,39 +35,26 @@ api_router = APIRouter() -async_routers = ( - ("audio_file_endpoint", audio_file_router, "/audio", "async"), - ("audio_url_endpoint", audio_url_router, "/audio-url", "async"), - ("youtube_endpoint", youtube_router, "/youtube", "async"), -) -live_routers = ("live_endpoint", live_router, "/live", "live") -transcribe_routers = ( - "transcribe_endpoint", - transcribe_router, - "/transcribe", - "transcription", -) -diarize_routers = ( - "diariaze_endpoint", - diarize_router, - "/diarize", - "diarization", -) +async_routers = [ + (audio_file_router, "/audio", "async"), + (audio_url_router, "/audio-url", "async"), + (youtube_router, "/youtube", "async"), +] +live_routers = (live_router, "/live", "live") +transcribe_routers = (transcribe_router, "/transcribe", "transcription") +diarize_routers = (diarize_router, "/diarize", "diarization") +routers = [] if settings.asr_type == "async": - routers = async_routers + routers.extend(async_routers) elif settings.asr_type == "live": - routers = live_routers + routers.append(live_routers) elif settings.asr_type == "only_transcription": - routers = transcribe_routers + routers.append(transcribe_routers) elif settings.asr_type == "only_diarization": - routers = diarize_routers -else: - raise ValueError(f"Invalid ASR type: {settings.asr_type}") + routers.append(diarize_routers) -for router_items in routers: - endpoint, router, prefix, tags = router_items - # If the endpoint is enabled, include it in the API. 
- if getattr(settings, endpoint) is True: - api_router.include_router(router, prefix=prefix, tags=[tags]) +for items in routers: + router, prefix, tags = items + api_router.include_router(router, prefix=prefix, tags=[tags]) diff --git a/src/wordcab_transcribe/router/v1/transcribe_endpoint.py b/src/wordcab_transcribe/router/v1/transcribe_endpoint.py index e1002c8..2fa3b6f 100644 --- a/src/wordcab_transcribe/router/v1/transcribe_endpoint.py +++ b/src/wordcab_transcribe/router/v1/transcribe_endpoint.py @@ -19,20 +19,37 @@ # and limitations under the License. """Transcribe endpoint for the Remote Wordcab Transcribe API.""" -from typing import Union +from typing import List, Union -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from fastapi import status as http_status +from loguru import logger -from wordcab_transcribe.models import TranscribeResponse +from wordcab_transcribe.dependencies import asr +from wordcab_transcribe.models import TranscribeRequest, TranscriptionOutput +from wordcab_transcribe.services.asr_service import ProcessException router = APIRouter() @router.post( "", - response_model=Union[TranscribeResponse, str], + response_model=Union[TranscriptionOutput, List[TranscriptionOutput], str], status_code=http_status.HTTP_200_OK, ) -async def remote_transcription() -> TranscribeResponse: - """Transcribe endpoint for the Remote Wordcab Transcribe API.""" +async def only_transcription( + data: TranscribeRequest, +) -> Union[TranscriptionOutput, List[TranscriptionOutput]]: + """Transcribe endpoint for the `only_transcription` asr type.""" + result: Union[TranscriptionOutput, List[TranscriptionOutput]] = ( + await asr.process_input(data) + ) + + if isinstance(result, ProcessException): + logger.error(result.message) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 4fbd0db..2f3f9ad 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -20,8 +20,6 @@ """ASR Service module that handle all AI interactions.""" import asyncio -import functools -import os import time import traceback from abc import ABC, abstractmethod @@ -29,14 +27,22 @@ from pathlib import Path from typing import Iterable, List, Tuple, Union +import aiohttp import torch from loguru import logger -from pydantic import BaseModel -from typing_extensions import Annotated - -from wordcab_transcribe.logging import time_and_tell -from wordcab_transcribe.models import ProcessTimes, Timestamps -from wordcab_transcribe.pydantic_annotations import TorchTensorPydanticAnnotation +from pydantic import BaseModel, ConfigDict +from tensorshare import Backend, TensorShare + +from wordcab_transcribe.logging import time_and_tell, time_and_tell_async +from wordcab_transcribe.models import ( + DiarizationOutput, + DiarizationRequest, + ProcessTimes, + Timestamps, + TranscribeRequest, + TranscriptionOutput, + Utterance, +) from wordcab_transcribe.services.concurrency_services import GPUService, URLService from wordcab_transcribe.services.diarization.diarize_service import DiarizeService from wordcab_transcribe.services.post_processing_service import PostProcessingService @@ -44,8 +50,6 @@ from wordcab_transcribe.services.vad_service import VadService from wordcab_transcribe.utils import early_return, format_segments, read_audio -PydanticTorchTensor = 
Annotated[torch.Tensor, TorchTensorPydanticAnnotation] - class ExceptionSource(str, Enum): """Exception source enum.""" @@ -77,7 +81,9 @@ class RemoteExecution(BaseModel): class ASRTask(BaseModel): """ASR Task model.""" - audio: Union[List[PydanticTorchTensor], PydanticTorchTensor] + model_config = ConfigDict(arbitrary_types_allowed=True) + + audio: Union[torch.Tensor, List[torch.Tensor]] diarization: "DiarizationTask" duration: float multi_channel: bool @@ -94,13 +100,13 @@ class DiarizationTask(BaseModel): execution: Union[LocalExecution, RemoteExecution, None] num_speakers: int - result: Union[ProcessException, List[dict], None] = None + result: Union[ProcessException, DiarizationOutput, None] = None class PostProcessingTask(BaseModel): """Post Processing Task model.""" - result: Union[ProcessException, List[dict], None] = None + result: Union[ProcessException, List[Utterance], None] = None class TranscriptionOptions(BaseModel): @@ -121,7 +127,9 @@ class TranscriptionTask(BaseModel): execution: Union[LocalExecution, RemoteExecution] options: TranscriptionOptions - result: Union[ProcessException, List[dict], None] = None + result: Union[ + ProcessException, TranscriptionOutput, List[TranscriptionOutput], None + ] = None class ASRService(ABC): @@ -137,7 +145,6 @@ def __init__(self) -> None: ) # Do we have a GPU? If so, use it! self.num_gpus = torch.cuda.device_count() if self.device == "cuda" else 0 logger.info(f"NVIDIA GPUs available: {self.num_gpus}") - self.num_cpus = os.cpu_count() if self.num_gpus > 1 and self.device == "cuda": self.device_index = list(range(self.num_gpus)) @@ -164,8 +171,8 @@ def __init__( window_lengths: List[int], shift_lengths: List[int], multiscale_weights: List[float], - extra_languages: List[str], - extra_languages_model_paths: List[str], + extra_languages: Union[List[str], None], + extra_languages_model_paths: Union[List[str], None], transcribe_server_urls: Union[List[str], None], diarize_server_urls: Union[List[str], None], debug_mode: bool, @@ -184,9 +191,9 @@ def __init__( The shift lengths to use for diarization. multiscale_weights (List[float]): The multiscale weights to use for diarization. - extra_languages (List[str]): + extra_languages (Union[List[str], None]): The list of extra languages to support. - extra_languages_model_paths (List[str]): + extra_languages_model_paths (Union[List[str], None]): The list of paths to the extra language models. use_remote_servers (bool): Whether to use remote servers for transcription and diarization. 
@@ -200,21 +207,6 @@ def __init__( super().__init__() self.services: dict = { - "transcription": TranscribeService( - model_path=whisper_model, - compute_type=compute_type, - device=self.device, - device_index=self.device_index, - extra_languages=extra_languages, - extra_languages_model_paths=extra_languages_model_paths, - ), - "diarization": DiarizeService( - device=self.device, - device_index=self.device_index, - window_lengths=window_lengths, - shift_lengths=shift_lengths, - multiscale_weights=multiscale_weights, - ), "post_processing": PostProcessingService(), "vad": VadService(), } @@ -234,12 +226,27 @@ def __init__( ) else: self.use_remote_transcription = False + self.services["transcription"] = TranscribeService( + model_path=whisper_model, + compute_type=compute_type, + device=self.device, + device_index=self.device_index, + extra_languages=extra_languages, + extra_languages_model_paths=extra_languages_model_paths, + ) if diarize_server_urls is not None: self.use_remote_diarization = True self.diarization_url_handler = URLService(remote_urls=diarize_server_urls) else: self.use_remote_diarization = False + self.services["diarization"] = DiarizeService( + device=self.device, + device_index=self.device_index, + window_lengths=window_lengths, + shift_lengths=shift_lengths, + multiscale_weights=multiscale_weights, + ) self.debug_mode = debug_mode @@ -407,17 +414,10 @@ async def process_input( # noqa: C901 try: start_process_time = time.time() - transcription_task = asyncio.get_event_loop().run_in_executor( - None, - functools.partial(self.process_transcription, task, self.debug_mode), - ) - diarization_task = asyncio.get_event_loop().run_in_executor( - None, - functools.partial(self.process_diarization, task, self.debug_mode), - ) + transcription_task = self.process_transcription(task, self.debug_mode) + diarization_task = self.process_diarization(task, self.debug_mode) - await transcription_task - await diarization_task + await asyncio.gather(transcription_task, diarization_task) if isinstance(task.diarization.result, ProcessException): return task.diarization.result @@ -434,7 +434,9 @@ async def process_input( # noqa: C901 return task.transcription.result await asyncio.get_event_loop().run_in_executor( - None, functools.partial(self.process_post_processing, task) + None, + self.process_post_processing, + task, ) if isinstance(task.post_processing.result, ProcessException): @@ -453,7 +455,7 @@ async def process_input( # noqa: C901 if gpu_index is not None: self.gpu_handler.release_device(gpu_index) - def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: + async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: """ Process a task of transcription and update the task with the result. 
@@ -466,7 +468,7 @@ def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: """ try: if isinstance(task.transcription.execution, LocalExecution): - result, process_time = time_and_tell( + out = await time_and_tell_async( lambda: self.services["transcription"]( task.audio, model_index=task.transcription.execution.index, @@ -477,8 +479,35 @@ def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: func_name="transcription", debug_mode=debug_mode, ) + result, process_time = out + elif isinstance(task.transcription.execution, RemoteExecution): - raise NotImplementedError("Remote execution is not implemented yet.") + if isinstance(task.audio, list): + ts = [ + TensorShare.from_dict({"audio": a}, backend=Backend.TORCH) + for a in task.audio + ] + else: + ts = TensorShare.from_dict( + {"audio": task.audio}, backend=Backend.TORCH + ) + + data = TranscribeRequest( + audio=ts, + **task.transcription.options.model_dump(), + ) + out = await time_and_tell_async( + self.remote_transcription( + url=task.transcription.execution.url, + data=data, + ), + func_name="transcription", + debug_mode=debug_mode, + ) + result, process_time = out + + else: + raise NotImplementedError("No execution method specified.") except Exception as e: result = ProcessException( @@ -493,7 +522,7 @@ def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: return None - def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: + async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: """ Process a task of diarization. @@ -506,7 +535,7 @@ def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: """ try: if isinstance(task.diarization.execution, LocalExecution): - result, process_time = time_and_tell( + out = await time_and_tell_async( lambda: self.services["diarization"]( waveform=task.audio, audio_duration=task.duration, @@ -517,12 +546,33 @@ def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: func_name="diarization", debug_mode=debug_mode, ) + result, process_time = out + elif isinstance(task.diarization.execution, RemoteExecution): - raise NotImplementedError("Remote execution is not implemented yet.") + ts = TensorShare.from_dict({"audio": task.audio}, backend=Backend.TORCH) + + data = DiarizationRequest( + audio=ts, + duration=task.duration, + num_speakers=task.diarization.num_speakers, + ) + out = await time_and_tell_async( + self.remote_diarization( + url=task.diarization.execution.url, + data=data, + ), + func_name="diarization", + debug_mode=debug_mode, + ) + result, process_time = out + elif task.diarization.execution is None: result = None process_time = None + else: + raise NotImplementedError("No execution method specified.") + except Exception as e: result = ProcessException( source=ExceptionSource.diarization, @@ -548,13 +598,12 @@ def process_post_processing(self, task: ASRTask) -> None: """ try: total_post_process_time = 0 - diarization = False if task.diarization.execution is None else True if task.multi_channel: utterances, process_time = time_and_tell( - lambda: self.services[ - "post_processing" - ].multi_channel_speaker_mapping(task.transcription.result), + self.services["post_processing"].multi_channel_speaker_mapping( + task.transcription.result + ), func_name="multi_channel_speaker_mapping", debug_mode=self.debug_mode, ) @@ -562,20 +611,17 @@ def process_post_processing(self, task: ASRTask) -> None: else: formatted_segments, process_time = time_and_tell( - lambda: format_segments( - 
segments=task.transcription.result, - word_timestamps=True, + format_segments( + transcription_output=task.transcription.result, ), func_name="format_segments", debug_mode=self.debug_mode, ) total_post_process_time += process_time - if diarization: + if task.diarization.execution is not None: utterances, process_time = time_and_tell( - lambda: self.services[ - "post_processing" - ].single_channel_speaker_mapping( + self.services["post_processing"].single_channel_speaker_mapping( transcript_segments=formatted_segments, speaker_timestamps=task.diarization.result, word_timestamps=task.word_timestamps, @@ -588,12 +634,8 @@ def process_post_processing(self, task: ASRTask) -> None: utterances = formatted_segments final_utterances, process_time = time_and_tell( - lambda: self.services[ - "post_processing" - ].final_processing_before_returning( + self.services["post_processing"].final_processing_before_returning( utterances=utterances, - diarization=diarization, - multi_channel=task.multi_channel, offset_start=task.offset_start, timestamps_format=task.timestamps_format, word_timestamps=task.word_timestamps, @@ -616,6 +658,41 @@ def process_post_processing(self, task: ASRTask) -> None: return None + async def remote_transcription( + self, + url: str, + data: TranscribeRequest, + ) -> TranscriptionOutput: + """Remote transcription method.""" + async with aiohttp.ClientSession() as session: + async with session.post( + url=f"{url}/api/v1/transcribe", + data=data.model_dump_json(), + headers={"Content-Type": "application/json"}, + ) as response: + if response.status != 200: + raise Exception(response.status) + else: + return TranscriptionOutput(**await response.json()) + + async def remote_diarization( + self, + url: str, + data: DiarizationRequest, + ) -> DiarizationOutput: + """Remote diarization method.""" + async with aiohttp.ClientSession() as session: + async with session.post( + url=f"{url}/api/v1/diarize", + data=data.model_dump_json(), + headers={"Content-Type": "application/json"}, + ) as response: + if response.status != 200: + r = await response.json() + raise Exception(r["detail"]) + else: + return DiarizationOutput(**await response.json()) + class ASRLiveService(ASRService): """ASR Service module for live endpoints.""" @@ -672,3 +749,169 @@ async def process_input(self, data: bytes, source_lang: str) -> Iterable[dict]: finally: self.gpu_handler.release_device(gpu_index) + + +class ASRTranscriptionOnly(ASRService): + """ASR Service module for transcription-only endpoint.""" + + def __init__( + self, + whisper_model: str, + compute_type: str, + extra_languages: Union[List[str], None], + extra_languages_model_paths: Union[List[str], None], + debug_mode: bool, + ) -> None: + """Initialize the ASRTranscriptionOnly class.""" + super().__init__() + + self.transcription_service = TranscribeService( + model_path=whisper_model, + compute_type=compute_type, + device=self.device, + device_index=self.device_index, + extra_languages=extra_languages, + extra_languages_model_paths=extra_languages_model_paths, + ) + self.debug_mode = debug_mode + + async def inference_warmup(self) -> None: + """Warmup the GPU by doing one inference.""" + sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav" + + audio, _ = read_audio(str(sample_audio)) + ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH) + + data = TranscribeRequest( + audio=ts, + source_lang="en", + compression_ratio_threshold=2.4, + condition_on_previous_text=True, + internal_vad=False, + log_prob_threshold=-1.0, + 
no_speech_threshold=0.6, + repetition_penalty=1.0, + vocab=None, + ) + + for gpu_index in self.gpu_handler.device_index: + logger.info(f"Warmup GPU {gpu_index}.") + await self.process_input(data=data) + + async def process_input( + self, data: TranscribeRequest + ) -> Union[TranscriptionOutput, List[TranscriptionOutput]]: + """ + Process the input data and return the results as a list of segments. + + Args: + data (TranscribeRequest): + The input data to process. + + Returns: + Union[TranscriptionOutput, List[TranscriptionOutput]]: + The results of the ASR pipeline. + """ + gpu_index = await self.gpu_handler.get_device() + + try: + result = self.transcription_service( + audio=data.audio, + source_lang=data.source_lang, + model_index=gpu_index, + suppress_blank=False, + word_timestamps=True, + compression_ratio_threshold=data.compression_ratio_threshold, + condition_on_previous_text=data.condition_on_previous_text, + internal_vad=data.internal_vad, + log_prob_threshold=data.log_prob_threshold, + repetition_penalty=data.repetition_penalty, + no_speech_threshold=data.no_speech_threshold, + vocab=data.vocab, + ) + + except Exception as e: + result = ProcessException( + source=ExceptionSource.transcription, + message=f"Error in transcription: {e}\n{traceback.format_exc()}", + ) + + finally: + self.gpu_handler.release_device(gpu_index) + + return result + + +class ASRDiarizationOnly(ASRService): + """ASR Service module for diarization-only endpoint.""" + + def __init__( + self, + window_lengths: List[int], + shift_lengths: List[int], + multiscale_weights: List[float], + debug_mode: bool, + ) -> None: + """Initialize the ASRDiarizationOnly class.""" + super().__init__() + + self.diarization_service = DiarizeService( + device=self.device, + device_index=self.device_index, + window_lengths=window_lengths, + shift_lengths=shift_lengths, + multiscale_weights=multiscale_weights, + ) + self.vad_service = VadService() + self.debug_mode = debug_mode + + async def inference_warmup(self) -> None: + """Warmup the GPU by doing one inference.""" + sample_audio = Path(__file__).parent.parent / "assets/warmup_sample.wav" + + audio, duration = read_audio(str(sample_audio)) + ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH) + + data = DiarizationRequest( + audio=ts, + duration=duration, + num_speakers=1, + ) + + for gpu_index in self.gpu_handler.device_index: + logger.info(f"Warmup GPU {gpu_index}.") + await self.process_input(data=data) + + async def process_input(self, data: DiarizationRequest) -> DiarizationOutput: + """ + Process the input data and return the results as a list of segments. + + Args: + data (DiarizationRequest): + The input data to process. + + Returns: + DiarizationOutput: + The results of the ASR pipeline. 
+ """ + gpu_index = await self.gpu_handler.get_device() + + try: + result = self.diarization_service( + waveform=data.audio, + audio_duration=data.duration, + oracle_num_speakers=data.num_speakers, + model_index=gpu_index, + vad_service=self.vad_service, + ) + + except Exception as e: + result = ProcessException( + source=ExceptionSource.diarization, + message=f"Error in diarization: {e}\n{traceback.format_exc()}", + ) + + finally: + self.gpu_handler.release_device(gpu_index) + + return result diff --git a/src/wordcab_transcribe/services/diarization/diarize_service.py b/src/wordcab_transcribe/services/diarization/diarize_service.py index f54b353..27f5c90 100644 --- a/src/wordcab_transcribe/services/diarization/diarize_service.py +++ b/src/wordcab_transcribe/services/diarization/diarize_service.py @@ -19,10 +19,12 @@ # and limitations under the License. """Diarization Service for audio files.""" -from typing import List, NamedTuple, Tuple +from typing import List, NamedTuple, Tuple, Union import torch +from tensorshare import Backend, TensorShare +from wordcab_transcribe.models import DiarizationOutput from wordcab_transcribe.services.diarization.clustering_module import ClusteringModule from wordcab_transcribe.services.diarization.models import ( MultiscaleEmbeddingsAndTimestamps, @@ -96,25 +98,35 @@ def __init__( def __call__( self, - waveform: torch.Tensor, + waveform: Union[torch.Tensor, TensorShare], audio_duration: float, oracle_num_speakers: int, model_index: int, vad_service: VadService, - ) -> List[dict]: + ) -> DiarizationOutput: """ Run inference with the diarization model. Args: - waveform (torch.Tensor): Waveform to run inference on. - audio_duration (float): Duration of the audio file in seconds. - oracle_num_speakers (int): Number of speakers in the audio file. - model_index (int): Index of the model to use for inference. - vad_service (VadService): VAD service instance to use for Voice Activity Detection. + waveform (Union[torch.Tensor, TensorShare]): + Waveform to run inference on. + audio_duration (float): + Duration of the audio file in seconds. + oracle_num_speakers (int): + Number of speakers in the audio file. + model_index (int): + Index of the model to use for inference. + vad_service (VadService): + VAD service instance to use for Voice Activity Detection. Returns: - List[dict]: List of segments with the following keys: "start", "end", "speaker". + DiarizationOutput: + List of segments with the following keys: "start", "end", "speaker". """ + if isinstance(waveform, TensorShare): + ts = waveform.to_tensors(backend=Backend.TORCH) + waveform = ts["audio"] + vad_outputs, _ = vad_service(waveform, group_timestamps=False) if len(vad_outputs) == 0: # Empty audio @@ -157,7 +169,7 @@ def __call__( _outputs = self.get_contiguous_stamps(clustering_outputs) outputs = self.merge_stamps(_outputs) - return outputs + return DiarizationOutput(segments=outputs) @staticmethod def get_contiguous_stamps( diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py index 8838006..23011be 100644 --- a/src/wordcab_transcribe/services/post_processing_service.py +++ b/src/wordcab_transcribe/services/post_processing_service.py @@ -19,11 +19,17 @@ # and limitations under the License. 
"""Post-Processing Service for audio files.""" -import itertools -from typing import List, Union +from typing import List, Tuple, Union -from wordcab_transcribe.models import Timestamps -from wordcab_transcribe.utils import convert_timestamp, format_punct, is_empty_string +from wordcab_transcribe.models import ( + DiarizationOutput, + DiarizationSegment, + Timestamps, + TranscriptionOutput, + Utterance, + Word, +) +from wordcab_transcribe.utils import convert_timestamp, format_punct class PostProcessingService: @@ -35,10 +41,10 @@ def __init__(self) -> None: def single_channel_speaker_mapping( self, - transcript_segments: List[dict], - speaker_timestamps: List[dict], + transcript_segments: List[Utterance], + speaker_timestamps: DiarizationOutput, word_timestamps: bool, - ) -> List[dict]: + ) -> List[Utterance]: """Run the post-processing functions on the inputs. The postprocessing pipeline is as follows: @@ -46,16 +52,20 @@ def single_channel_speaker_mapping( 2. Group utterances of the same speaker together. Args: - transcript_segments (List[dict]): List of transcript segments. - speaker_timestamps (List[dict]): List of speaker timestamps. - word_timestamps (bool): Whether to include word timestamps. + transcript_segments (List[Utterance]): + List of transcript utterances. + speaker_timestamps (DiarizationOutput): + List of speaker timestamps. + word_timestamps (bool): + Whether to include word timestamps. Returns: - List[dict]: List of sentences with speaker mapping. + List[Utterance]: + List of utterances with speaker mapping. """ segments_with_speaker_mapping = self.segments_speaker_mapping( transcript_segments, - speaker_timestamps, + speaker_timestamps.segments, ) utterances = self.reconstruct_utterances( @@ -65,28 +75,27 @@ def single_channel_speaker_mapping( return utterances def multi_channel_speaker_mapping( - self, multi_channel_segments: List[List[dict]] - ) -> List[dict]: + self, multi_channel_segments: List[TranscriptionOutput] + ) -> TranscriptionOutput: """ Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps. Args: - multi_channel_segments (List[dict]): List of segments from both speakers. + multi_channel_segments (List[TranscriptionOutput]): + List of segments from multi speakers. Returns: - List[dict]: List of sentences with speaker mapping. + TranscriptionOutput: List of sentences with speaker mapping. """ - words_with_speaker_mapping = [] - - for segment in list(itertools.chain.from_iterable(multi_channel_segments)): - speaker = segment["speaker"] - for word in segment["words"]: - word.update({"speaker": speaker}) - words_with_speaker_mapping.append(word) - - words_with_speaker_mapping.sort(key=lambda word: word["start"]) + words_with_speaker_mapping = [ + (segment.speaker, word) + for output in multi_channel_segments + for segment in output.segments + for word in segment.words + ] + words_with_speaker_mapping.sort(key=lambda _, word: word.start) - utterances = self.reconstruct_multi_channel_utterances( + utterances: List[Utterance] = self.reconstruct_multi_channel_utterances( words_with_speaker_mapping ) @@ -94,8 +103,8 @@ def multi_channel_speaker_mapping( def segments_speaker_mapping( self, - transcript_segments: List[dict], - speaker_timestamps: List[dict], + transcript_segments: List[Utterance], + speaker_timestamps: List[DiarizationSegment], ) -> List[dict]: """Function to map transcription and diarization results. 
@@ -115,11 +124,11 @@ def segments_speaker_mapping( segment_index = 0 segment_speaker_mapping = [] while segment_index < len(transcript_segments): - segment = transcript_segments[segment_index] + segment: Utterance = transcript_segments[segment_index] segment_start, segment_end, segment_text = ( - segment["start"], - segment["end"], - segment["text"], + segment.start, + segment.end, + segment.text, ) while segment_start > float(end) or abs(segment_start - float(end)) < 0.3: @@ -131,14 +140,13 @@ def segments_speaker_mapping( break if segment_end > float(end): - words = segment["words"] + words = segment.words word_index = next( ( i for i, word in enumerate(words) - if word["start"] > float(end) - or abs(word["start"] - float(end)) < 0.3 + if word.start > float(end) or abs(word.start - float(end)) < 0.3 ), None, ) @@ -147,52 +155,52 @@ def segments_speaker_mapping( _splitted_segment = segment_text.split() if word_index > 0: - _segment_to_add = { - "start": words[0]["start"], - "end": words[word_index - 1]["end"], - "text": " ".join(_splitted_segment[:word_index]), - "speaker": speaker, - "words": words[:word_index], - } + _segment_to_add = Utterance( + start=words[0].start, + end=words[word_index - 1].end, + text=" ".join(_splitted_segment[:word_index]), + speaker=speaker, + words=words[:word_index], + ) else: - _segment_to_add = { - "start": words[0]["start"], - "end": words[0]["end"], - "text": _splitted_segment[0], - "speaker": speaker, - "words": words[:1], - } + _segment_to_add = Utterance( + start=words[0].start, + end=words[0].end, + text=_splitted_segment[0], + speaker=speaker, + words=words[:1], + ) segment_speaker_mapping.append(_segment_to_add) transcript_segments.insert( segment_index + 1, - { - "start": words[word_index]["start"], - "end": segment_end, - "text": " ".join(_splitted_segment[word_index:]), - "words": words[word_index:], - }, + Utterance( + start=words[word_index].start, + end=segment_end, + text=" ".join(_splitted_segment[word_index:]), + words=words[word_index:], + ), ) else: segment_speaker_mapping.append( - { - "start": segment_start, - "end": segment_end, - "text": segment_text, - "speaker": speaker, - "words": words, - } + Utterance( + start=segment_start, + end=segment_end, + text=segment_text, + speaker=speaker, + words=words, + ) ) else: segment_speaker_mapping.append( - { - "start": segment_start, - "end": segment_end, - "text": segment_text, - "speaker": speaker, - "words": segment["words"], - } + Utterance( + start=segment_start, + end=segment_end, + text=segment_text, + speaker=speaker, + words=segment.words, + ) ) segment_index += 1 @@ -201,23 +209,26 @@ def segments_speaker_mapping( def reconstruct_utterances( self, - transcript_words: List[dict], + transcript_segments: List[Utterance], word_timestamps: bool, - ) -> List[dict]: + ) -> List[Utterance]: """ Reconstruct the utterances based on the speaker mapping. Args: - transcript_words (List[dict]): List of transcript words. - word_timestamps (bool): Whether to include word timestamps. + transcript_words (List[Utterance]): + List of transcript segments. + word_timestamps (bool): + Whether to include word timestamps. Returns: - List[dict]: List of sentences with speaker mapping. + List[Utterance]: + List of sentences with speaker mapping. 
""" start_t0, end_t0, speaker_t0 = ( - transcript_words[0]["start"], - transcript_words[0]["end"], - transcript_words[0]["speaker"], + transcript_segments[0].start, + transcript_segments[0].end, + transcript_segments[0].speaker, ) previous_speaker = speaker_t0 @@ -231,12 +242,12 @@ def reconstruct_utterances( current_sentence["words"] = [] sentences = [] - for segment in transcript_words: - text, speaker = segment["text"], segment["speaker"] - start_t, end_t = segment["start"], segment["end"] + for segment in transcript_segments: + text, speaker = segment.text, segment.speaker + start_t, end_t = segment.start, segment.end if speaker != previous_speaker: - sentences.append(current_sentence) + sentences.append(Utterance(**current_sentence)) current_sentence = { "speaker": speaker, "start": start_t, @@ -251,31 +262,29 @@ def reconstruct_utterances( current_sentence["text"] += text + " " previous_speaker = speaker if word_timestamps: - current_sentence["words"].extend(segment["words"]) + current_sentence["words"].extend(segment.words) # Catch the last sentence - sentences.append(current_sentence) + sentences.append(Utterance(**current_sentence)) return sentences def reconstruct_multi_channel_utterances( self, - transcript_words: List[dict], - ) -> List[dict]: + transcript_words: List[Tuple[int, Word]], + ) -> List[Utterance]: """ Reconstruct multi-channel utterances based on the speaker mapping. Args: - transcript_words (List[dict]): List of transcript words. + transcript_words (List[Tuple[int, Word]]): + List of tuples containing the speaker and the word. Returns: - List[dict]: List of sentences with speaker mapping. + List[Utterance]: List of sentences with speaker mapping. """ - start_t0, end_t0, speaker_t0 = ( - transcript_words[0]["start"], - transcript_words[0]["end"], - transcript_words[0]["speaker"], - ) + speaker_t0, word = transcript_words[0] + start_t0, end_t0 = word.start, word.end previous_speaker = speaker_t0 current_sentence = { @@ -287,9 +296,8 @@ def reconstruct_multi_channel_utterances( } sentences = [] - for word in transcript_words: - text, speaker = word["word"], word["speaker"] - start_t, end_t = word["start"], word["end"] + for speaker, word in transcript_words: + start_t, end_t, text = word.start, word.end, word.word if speaker != previous_speaker: sentences.append(current_sentence) @@ -313,27 +321,21 @@ def reconstruct_multi_channel_utterances( for sentence in sentences: sentence["text"] = sentence["text"].strip() - return sentences + return [Utterance(**sentence) for sentence in sentences] def final_processing_before_returning( self, - utterances: List[dict], - diarization: bool, - multi_channel: bool, + utterances: List[Utterance], offset_start: Union[float, None], timestamps_format: Timestamps, word_timestamps: bool, - ) -> List[dict]: + ) -> List[Utterance]: """ Do final processing before returning the utterances to the API. Args: - utterances (List[dict]): + utterances (List[Utterance]): List of utterances. - diarization (bool): - Whether diarization is enabled. - multi_channel (bool): - Whether multi-channel is enabled. offset_start (Union[float, None]): Offset start. timestamps_format (Timestamps): @@ -342,29 +344,22 @@ def final_processing_before_returning( Whether to include word timestamps. Returns: - List[dict]: List of utterances with final processing. + List[Utterance]: + List of utterances with final processing. 
""" if offset_start is not None: offset_start = float(offset_start) else: offset_start = 0.0 - include_speaker = diarization or multi_channel - - _utterances = [ - { - "text": format_punct(utterance["text"]), - "start": convert_timestamp( - (utterance["start"] + offset_start), timestamps_format - ), - "end": convert_timestamp( - (utterance["end"] + offset_start), timestamps_format - ), - "speaker": int(utterance["speaker"]) if include_speaker else None, - "words": utterance["words"] if word_timestamps else [], - } - for utterance in utterances - if not is_empty_string(utterance["text"]) - ] + for utterance in utterances: + utterance.text = format_punct(utterance.text) + utterance.start = convert_timestamp( + (utterance.start + offset_start), timestamps_format + ) + utterance.end = convert_timestamp( + (utterance.end + offset_start), timestamps_format + ) + utterance.words = utterance.words if word_timestamps else None - return _utterances + return utterances diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 59cac1e..eba00f4 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -24,6 +24,9 @@ import torch from faster_whisper import WhisperModel from loguru import logger +from tensorshare import Backend, TensorShare + +from wordcab_transcribe.models import TranscriptionOutput class FasterWhisperModel(NamedTuple): @@ -79,7 +82,14 @@ def __init__( def __call__( self, - audio: Union[str, torch.Tensor, List[str], List[torch.Tensor]], + audio: Union[ + str, + torch.Tensor, + TensorShare, + List[str], + List[torch.Tensor], + List[TensorShare], + ], source_lang: str, model_index: int, suppress_blank: bool = False, @@ -91,12 +101,12 @@ def __call__( log_prob_threshold: float = -1.0, no_speech_threshold: float = 0.6, condition_on_previous_text: bool = True, - ) -> Union[List[dict], List[List[dict]]]: + ) -> Union[TranscriptionOutput, List[TranscriptionOutput]]: """ Run inference with the transcribe model. Args: - audio (Union[str, torch.Tensor, List[str], List[torch.Tensor]]): + audio (Union[str, torch.Tensor, TensorShare, List[str], List[torch.Tensor], List[TensorShare]]): Audio file path or audio tensor. If a tuple is passed, the task is assumed to be a multi_channel task and the list of audio files or tensors is passed. source_lang (str): @@ -126,8 +136,8 @@ def __call__( to getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. Returns: - Union[List[dict], List[List[dict]]]: List of transcriptions. If the task is a multi_channel task, - a list of lists is returned. + Union[TranscriptionOutput, List[TranscriptionOutput]]: + Transcription output. If the task is a multi_channel task, a list of TranscriptionOutput is returned. 
""" # Extra language models are disabled until we can handle an index mapping # if ( @@ -172,6 +182,9 @@ def __call__( if not isinstance(audio, list): if isinstance(audio, torch.Tensor): audio = audio.numpy() + elif isinstance(audio, TensorShare): + ts = audio.to_tensors(backend=Backend.NUMPY) + audio = ts["audio"] segments, _ = self.model.transcribe( audio, @@ -213,7 +226,8 @@ def __call__( vad_filter=False if internal_vad else True, ) - outputs = [segment._asdict() for segment in segments] + _outputs = [segment._asdict() for segment in segments] + outputs = TranscriptionOutput(segments=_outputs) else: outputs = [] @@ -286,7 +300,7 @@ def live_transcribe( def multi_channel( self, - audio: Union[str, torch.Tensor], + audio: Union[str, torch.Tensor, TensorShare], source_lang: str, speaker_id: int, suppress_blank: bool = False, @@ -298,12 +312,12 @@ def multi_channel( no_speech_threshold: float = 0.6, condition_on_previous_text: bool = False, prompt: Optional[str] = None, - ) -> List[dict]: + ) -> TranscriptionOutput: """ Transcribe an audio file using the faster-whisper original pipeline. Args: - audio (Union[str, torch.Tensor]): Audio file path or loaded audio. + audio (Union[str, torch.Tensor, TensorShare]): Audio file path or loaded audio. source_lang (str): Language of the audio file. speaker_id (int): Speaker ID used in the diarization. suppress_blank (bool): @@ -328,15 +342,18 @@ def multi_channel( prompt (Optional[str]): Initial prompt to use for the generation. Returns: - List[dict]: List of transcribed segments. + TranscriptionOutput: Transcription output. """ if isinstance(audio, torch.Tensor): - audio = audio.numpy() + _audio = audio.numpy() + elif isinstance(audio, TensorShare): + ts = audio.to_tensors(backend=Backend.NUMPY) + _audio = ts["audio"] - final_transcript = [] + final_segments = [] segments, _ = self.model.transcribe( - audio, + _audio, language=source_lang, initial_prompt=prompt, repetition_penalty=repetition_penalty, @@ -378,6 +395,6 @@ def multi_channel( segment_dict["start"] = segment_dict["words"][0]["start"] segment_dict["end"] = segment_dict["words"][-1]["end"] - final_transcript.append(segment_dict) + final_segments.append(segment_dict) - return final_transcript + return TranscriptionOutput(segments=final_segments) diff --git a/src/wordcab_transcribe/utils.py b/src/wordcab_transcribe/utils.py index 9c95c44..fcb752a 100644 --- a/src/wordcab_transcribe/utils.py +++ b/src/wordcab_transcribe/utils.py @@ -38,7 +38,12 @@ if TYPE_CHECKING: from fastapi import UploadFile -from wordcab_transcribe.models import Timestamps +from wordcab_transcribe.models import ( + Timestamps, + TranscriptionOutput, + Utterance, + Word, +) # pragma: no cover @@ -404,38 +409,34 @@ def format_punct(text: str): return text.strip() -def format_segments(segments: list, word_timestamps: bool) -> List[dict]: +def format_segments(transcription_output: TranscriptionOutput) -> List[Utterance]: """ Format the segments to a list of dicts with start, end and text keys. Optionally include word timestamps. Args: - segments (list): List of segments. + transcription_output (TranscriptionOutput): List of segments. word_timestamps (bool): Whether to include word timestamps. Returns: - list: List of dicts with start, end and word keys. 
- """ - formatted_segments = [] - - for segment in segments: - segment_dict = {} - - segment_dict["start"] = segment["start"] - segment_dict["end"] = segment["end"] - segment_dict["text"] = segment["text"].strip() - if word_timestamps: - _words = [ - { - "word": word.word.strip(), - "start": word.start, - "end": word.end, - "score": round(word.probability, 2), - } - for word in segment["words"] - ] - segment_dict["words"] = _words - - formatted_segments.append(segment_dict) + List[Utterance]: List of formatted segments. + """ + formatted_segments = [ + Utterance( + text=segment.text, + start=segment.start, + end=segment.end, + words=[ + Word( + word=word.word, + start=word.start, + end=word.end, + probability=word.probability, + ) + for word in segment.words + ], + ) + for segment in transcription_output.segments + ] return formatted_segments @@ -546,15 +547,17 @@ def read_audio( Tuple[torch.Tensor, float]: The audio tensor and the audio duration. """ if isinstance(audio, str): - wav, sr = torchaudio.load( - audio, - ) + wav, sr = torchaudio.load(audio) elif isinstance(audio, bytes): with io.BytesIO(audio) as buffer: wav, sr = sf.read( buffer, format="RAW", channels=1, samplerate=16000, subtype="PCM_16" ) wav = torch.from_numpy(wav).unsqueeze(0) + else: + raise ValueError( + f"Invalid audio type. Must be either str or bytes, got: {type(audio)}." + ) if wav.size(0) > 1: wav = wav.mean(dim=0, keepdim=True) diff --git a/tests/test_config.py b/tests/test_config.py index 789d7ad..dc5b204 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -47,11 +47,7 @@ def default_settings() -> OrderedDict: shift_lengths=[0.75, 0.625, 0.5, 0.375, 0.25], multiscale_weights=[1.0, 1.0, 1.0, 1.0, 1.0], asr_type="async", - audio_file_endpoint=True, - audio_url_endpoint=True, cortex_endpoint=True, - youtube_endpoint=True, - live_endpoint=False, username="admin", password="admin", openssl_key="0123456789abcdefghijklmnopqrstuvwyz", @@ -85,12 +81,7 @@ def test_config() -> None: assert settings.multiscale_weights == [1.0, 1.0, 1.0, 1.0, 1.0] assert settings.asr_type == "async" - - assert settings.audio_file_endpoint is True - assert settings.audio_url_endpoint is True assert settings.cortex_endpoint is True - assert settings.youtube_endpoint is True - assert settings.live_endpoint is False assert settings.username == "admin" # noqa: S105 assert settings.password == "admin" # noqa: S105 @@ -152,15 +143,3 @@ def test_access_token_expire_minutes_validator(default_settings: dict) -> None: default_settings["access_token_expire_minutes"] = -1 with pytest.raises(ValueError): Settings(**default_settings) - - -def test_post_init(default_settings: dict) -> None: - """Test post init.""" - wrong_endpoint = default_settings.copy() - wrong_endpoint["audio_file_endpoint"] = False - wrong_endpoint["audio_url_endpoint"] = False - wrong_endpoint["cortex_endpoint"] = False - wrong_endpoint["live_endpoint"] = False - wrong_endpoint["youtube_endpoint"] = False - with pytest.raises(ValueError): - Settings(**wrong_endpoint) diff --git a/tests/test_models.py b/tests/test_models.py index e1447b0..080a78e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -59,20 +59,6 @@ def test_timestamps() -> None: assert Timestamps.hour_minute_second == "hms" -def test_word() -> None: - """Test the Word model.""" - word = Word( - word="test", - start=0.0, - end=1.0, - score=0.9, - ) - assert word.word == "test" - assert word.start == 0.0 - assert word.end == 1.0 - assert word.score == 0.9 - - def test_utterance() -> None: """Test 
the Utterance model."""
     utterance = Utterance(
@@ -85,25 +71,25 @@ def test_utterance() -> None:
                 word="This",
                 start=0.0,
                 end=1.0,
-                score=0.9,
+                probability=0.9,
             ),
             Word(
                 word="is",
                 start=1.0,
                 end=2.0,
-                score=0.75,
+                probability=0.75,
             ),
             Word(
                 word="a",
                 start=2.0,
                 end=3.0,
-                score=0.8,
+                probability=0.8,
             ),
             Word(
                 word="test.",
                 start=3.0,
                 end=4.0,
-                score=0.85,
+                probability=0.85,
             ),
         ],
     )
@@ -117,25 +103,25 @@ def test_utterance() -> None:
             word="This",
             start=0.0,
             end=1.0,
-            score=0.9,
+            probability=0.9,
         ),
         Word(
             word="is",
             start=1.0,
             end=2.0,
-            score=0.75,
+            probability=0.75,
         ),
         Word(
             word="a",
             start=2.0,
             end=3.0,
-            score=0.8,
+            probability=0.8,
         ),
         Word(
             word="test.",
             start=3.0,
             end=4.0,
-            score=0.85,
+            probability=0.85,
         ),
     ]
     assert isinstance(utterance.words[0], Word)
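
A minimal client sketch for the new `only_diarization` deployment, mirroring notebooks/transcribe_endpoint_only.py above. Only the DiarizationRequest fields (audio, duration, num_speakers) and the /api/v1/diarize route come from this diff; the host and port, the sample file path, and the use of -1 to let the service infer the speaker count are assumptions, not part of the change set.

import asyncio
import json

import aiohttp
import torchaudio
from pydantic import BaseModel
from tensorshare import Backend, TensorShare


class DiarizationRequest(BaseModel):
    """Local mirror of the request model added in src/wordcab_transcribe/models.py."""

    audio: TensorShare
    duration: float
    num_speakers: int


async def main() -> None:
    # Hypothetical sample file; any audio torchaudio can decode will do.
    wav, sr = torchaudio.load("data/HL_Podcast_1.mp3")
    if sr != 16000:
        wav = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)(wav)
    audio = wav.mean(dim=0)  # mono, 1-D tensor like the one produced by read_audio
    duration = float(audio.shape[0]) / 16000

    ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)
    data = DiarizationRequest(audio=ts, duration=duration, num_speakers=-1)

    async with aiohttp.ClientSession() as session:
        async with session.post(
            url="http://0.0.0.0:5002/api/v1/diarize",  # assumed host/port
            data=data.model_dump_json(),
            headers={"Content-Type": "application/json"},
        ) as response:
            if response.status != 200:
                raise Exception(
                    f"Remote diarization failed with status {response.status}."
                )
            r = await response.json()

    # The response follows DiarizationOutput: a "segments" list of (start, end, speaker) entries.
    with open("remote_diarization_test.json", "w") as f:
        f.write(json.dumps(r, indent=4))


asyncio.run(main())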