From 147e2956920f4d8690bc52bfe730bbae983f0b8f Mon Sep 17 00:00:00 2001 From: Thomas Chaigneau Date: Thu, 5 Oct 2023 09:29:56 +0200 Subject: [PATCH] Create `add` and `remove` url endpoints (#263) * create add and remove url endpoints * add get_url + update endpoints --- src/wordcab_transcribe/models.py | 9 +- src/wordcab_transcribe/router/v1/endpoints.py | 9 + .../router/v1/manage_remote_url.py | 97 +++++++++++ .../services/asr_service.py | 161 ++++++++++++++++-- .../services/concurrency_services.py | 69 ++++++-- 5 files changed, 317 insertions(+), 28 deletions(-) create mode 100644 src/wordcab_transcribe/router/v1/manage_remote_url.py diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py index cf1df0d..86a3e93 100644 --- a/src/wordcab_transcribe/models.py +++ b/src/wordcab_transcribe/models.py @@ -23,7 +23,7 @@ from typing import List, Literal, NamedTuple, Optional, Union from faster_whisper.transcribe import Segment -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, HttpUrl, field_validator from tensorshare import TensorShare @@ -484,6 +484,13 @@ class Config: } +class UrlSchema(BaseModel): + """Request model for the add_url endpoint.""" + + task: Literal["transcription", "diarization"] + url: HttpUrl + + class DiarizationSegment(NamedTuple): """Diarization segment model for the API.""" diff --git a/src/wordcab_transcribe/router/v1/endpoints.py b/src/wordcab_transcribe/router/v1/endpoints.py index 6ff8e6d..fad4397 100644 --- a/src/wordcab_transcribe/router/v1/endpoints.py +++ b/src/wordcab_transcribe/router/v1/endpoints.py @@ -30,6 +30,9 @@ ) from wordcab_transcribe.router.v1.diarize_endpoint import router as diarize_router from wordcab_transcribe.router.v1.live_endpoint import router as live_router +from wordcab_transcribe.router.v1.manage_remote_url import ( + router as manage_remote_url_router, +) from wordcab_transcribe.router.v1.transcribe_endpoint import router as transcribe_router from wordcab_transcribe.router.v1.youtube_endpoint import router as youtube_router @@ -43,10 +46,16 @@ live_routers = (live_router, "/live", "live") transcribe_routers = (transcribe_router, "/transcribe", "transcription") diarize_routers = (diarize_router, "/diarize", "diarization") +manage_remote_url_routers = ( + manage_remote_url_router, + "/url", + "remote-url", +) routers = [] if settings.asr_type == "async": routers.extend(async_routers) + routers.append(manage_remote_url_routers) elif settings.asr_type == "live": routers.append(live_routers) elif settings.asr_type == "only_transcription": diff --git a/src/wordcab_transcribe/router/v1/manage_remote_url.py b/src/wordcab_transcribe/router/v1/manage_remote_url.py new file mode 100644 index 0000000..536f258 --- /dev/null +++ b/src/wordcab_transcribe/router/v1/manage_remote_url.py @@ -0,0 +1,97 @@ +# Copyright 2023 The Wordcab Team. All rights reserved. +# +# Licensed under the Wordcab Transcribe License 0.1 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE +# +# Except as expressly provided otherwise herein, and to the fullest +# extent permitted by law, Licensor provides the Software (and each +# Contributor provides its Contributions) AS IS, and Licensor +# disclaims all warranties or guarantees of any kind, express or +# implied, whether arising under any law or from any usage in trade, +# or otherwise including but not limited to the implied warranties +# of merchantability, non-infringement, quiet enjoyment, fitness +# for a particular purpose, or otherwise. +# +# See the License for the specific language governing permissions +# and limitations under the License. +"""Add Remote URL endpoint for remote transcription or diarization.""" + +from typing import List, Union + +from fastapi import APIRouter, HTTPException +from fastapi import status as http_status +from loguru import logger +from pydantic import HttpUrl +from typing_extensions import Literal + +from wordcab_transcribe.dependencies import asr +from wordcab_transcribe.models import UrlSchema +from wordcab_transcribe.services.asr_service import ExceptionSource, ProcessException + +router = APIRouter() + + +@router.get( + "", + response_model=Union[List[HttpUrl], str], + status_code=http_status.HTTP_200_OK, +) +async def get_url(task: Literal["transcription", "diarization"]) -> List[HttpUrl]: + """Get Remote URL endpoint for remote transcription or diarization.""" + result: List[UrlSchema] = await asr.get_url(task) + + if isinstance(result, ProcessException): + logger.error(result.message) + if result.source == ExceptionSource.get_url: + raise HTTPException( + status_code=http_status.HTTP_405_METHOD_NOT_ALLOWED, + detail=str(result.message), + ) + else: + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result + + +@router.post( + "/add", + response_model=Union[UrlSchema, str], + status_code=http_status.HTTP_200_OK, +) +async def add_url(data: UrlSchema) -> UrlSchema: + """Add Remote URL endpoint for remote transcription or diarization.""" + result: UrlSchema = await asr.add_url(data) + + if isinstance(result, ProcessException): + logger.error(result.message) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result + + +@router.post( + "/remove", + response_model=Union[UrlSchema, str], + status_code=http_status.HTTP_200_OK, +) +async def remove_url(data: UrlSchema) -> UrlSchema: + """Remove Remote URL endpoint for remote transcription or diarization.""" + result: UrlSchema = await asr.remove_url(data) + + if isinstance(result, ProcessException): + logger.error(result.message) + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 2f3f9ad..a25c047 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -23,6 +23,7 @@ import time import traceback from abc import ABC, abstractmethod +from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import Iterable, List, Tuple, Union @@ -32,6 +33,7 @@ from loguru import logger from pydantic import BaseModel, ConfigDict from tensorshare import Backend, TensorShare +from typing_extensions import Literal from wordcab_transcribe.logging import time_and_tell, time_and_tell_async from wordcab_transcribe.models import ( @@ -41,6 +43,7 @@ Timestamps, TranscribeRequest, TranscriptionOutput, + UrlSchema, Utterance, ) from wordcab_transcribe.services.concurrency_services import GPUService, URLService @@ -54,8 +57,11 @@ class ExceptionSource(str, Enum): """Exception source enum.""" + add_url = "add_url" diarization = "diarization" + get_url = "get_url" post_processing = "post_processing" + remove_url = "remove_url" transcription = "transcription" @@ -132,6 +138,16 @@ class TranscriptionTask(BaseModel): ] = None +@dataclass +class ServiceHandler: + """Services handler model.""" + + diarization: Union[DiarizeService, None] = None + post_processing: PostProcessingService = PostProcessingService() + transcription: Union[TranscribeService, None] = None + vad: VadService = VadService() + + class ASRService(ABC): """Base ASR Service module that handle all AI interactions and batch processing.""" @@ -206,10 +222,7 @@ def __init__( """ super().__init__() - self.services: dict = { - "post_processing": PostProcessingService(), - "vad": VadService(), - } + self.services: ServiceHandler = ServiceHandler() self.dual_channel_transcribe_options: dict = { "beam_size": 5, "patience": 1, @@ -220,13 +233,21 @@ def __init__( } if transcribe_server_urls is not None: + logger.info( + "You provided URLs for remote transcription server, no local model will" + " be used." + ) self.use_remote_transcription = True self.transcription_url_handler = URLService( remote_urls=transcribe_server_urls ) else: + logger.info( + "You did not provide URLs for remote transcription server, local model" + " will be used." + ) self.use_remote_transcription = False - self.services["transcription"] = TranscribeService( + self.services.transcription = TranscribeService( model_path=whisper_model, compute_type=compute_type, device=self.device, @@ -236,11 +257,19 @@ def __init__( ) if diarize_server_urls is not None: + logger.info( + "You provided URLs for remote diarization server, no local model will" + " be used." + ) self.use_remote_diarization = True self.diarization_url_handler = URLService(remote_urls=diarize_server_urls) else: + logger.info( + "You did not provide URLs for remote diarization server, local model" + " will be used." + ) self.use_remote_diarization = False - self.services["diarization"] = DiarizeService( + self.services.diarization = DiarizeService( device=self.device, device_index=self.device_index, window_lengths=window_lengths, @@ -469,7 +498,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None: try: if isinstance(task.transcription.execution, LocalExecution): out = await time_and_tell_async( - lambda: self.services["transcription"]( + lambda: self.services.transcription( task.audio, model_index=task.transcription.execution.index, suppress_blank=False, @@ -536,12 +565,12 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: try: if isinstance(task.diarization.execution, LocalExecution): out = await time_and_tell_async( - lambda: self.services["diarization"]( + lambda: self.services.diarization( waveform=task.audio, audio_duration=task.duration, oracle_num_speakers=task.diarization.num_speakers, model_index=task.diarization.execution.index, - vad_service=self.services["vad"], + vad_service=self.services.vad, ), func_name="diarization", debug_mode=debug_mode, @@ -601,7 +630,7 @@ def process_post_processing(self, task: ASRTask) -> None: if task.multi_channel: utterances, process_time = time_and_tell( - self.services["post_processing"].multi_channel_speaker_mapping( + self.services.post_processing.multi_channel_speaker_mapping( task.transcription.result ), func_name="multi_channel_speaker_mapping", @@ -621,7 +650,7 @@ def process_post_processing(self, task: ASRTask) -> None: if task.diarization.execution is not None: utterances, process_time = time_and_tell( - self.services["post_processing"].single_channel_speaker_mapping( + self.services.post_processing.single_channel_speaker_mapping( transcript_segments=formatted_segments, speaker_timestamps=task.diarization.result, word_timestamps=task.word_timestamps, @@ -634,7 +663,7 @@ def process_post_processing(self, task: ASRTask) -> None: utterances = formatted_segments final_utterances, process_time = time_and_tell( - self.services["post_processing"].final_processing_before_returning( + self.services.post_processing.final_processing_before_returning( utterances=utterances, offset_start=task.offset_start, timestamps_format=task.timestamps_format, @@ -693,6 +722,114 @@ async def remote_diarization( else: return DiarizationOutput(**await response.json()) + async def get_url( + self, task: Literal["transcription", "diarization"] + ) -> Union[List[str], ProcessException]: + """Get the list of remote URLs.""" + try: + if task == "transcription": + # Case 1: We are not using remote transcription + if self.use_remote_transcription is False: + return ProcessException( + source=ExceptionSource.get_url, + message="You are not using remote transcription.", + ) + # Case 2: We are using remote transcription + else: + return self.transcription_url_handler.get_urls() + + elif task == "diarization": + # Case 1: We are not using remote diarization + if self.use_remote_diarization is False: + return ProcessException( + source=ExceptionSource.get_url, + message="You are not using remote diarization.", + ) + # Case 2: We are using remote diarization + else: + return self.diarization_url_handler.get_urls() + + else: + raise ValueError(f"{task} is not a valid task.") + + except Exception as e: + return ProcessException( + source=ExceptionSource.get_url, + message=f"Error in getting URL: {e}\n{traceback.format_exc()}", + ) + + async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]: + """Add a remote URL to the list of URLs.""" + try: + if data.task == "transcription": + # Case 1: We are not using remote transcription yet + if self.use_remote_transcription is False: + self.use_remote_transcription = True + self.transcription_url_handler = URLService( + remote_urls=[str(data.url)] + ) + # Case 2: We are already using remote transcription + else: + await self.transcription_url_handler.add_url(data.url) + + elif data.task == "diarization": + # Case 1: We are not using remote diarization yet + if self.use_remote_diarization is False: + self.use_remote_diarization = True + self.diarization_url_handler = URLService( + remote_urls=[str(data.url)] + ) + # Case 2: We are already using remote diarization + else: + await self.diarization_url_handler.add_url(data.url) + + else: + raise ValueError(f"{data.task} is not a valid task.") + + except Exception as e: + return ProcessException( + source=ExceptionSource.add_url, + message=f"Error in adding URL: {e}\n{traceback.format_exc()}", + ) + + return data + + async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]: + """Remove a remote URL from the list of URLs.""" + try: + if data.task == "transcription": + # Case 1: We are not using remote transcription + if self.use_remote_transcription is False: + raise ValueError("You are not using remote transcription.") + # Case 2: We are using remote transcription + else: + await self.transcription_url_handler.remove_url(str(data.url)) + if self.transcription_url_handler.get_queue_size() == 0: + # TODO: Add a way to switch back to local transcription + pass + + elif data.task == "diarization": + # Case 1: We are not using remote diarization + if self.use_remote_diarization is False: + raise ValueError("You are not using remote diarization.") + # Case 2: We are using remote diarization + else: + await self.diarization_url_handler.remove_url(str(data.url)) + if self.diarization_url_handler.get_queue_size() == 0: + # TODO: Add a way to switch back to local diarization + pass + + else: + raise ValueError(f"{data.task} is not a valid task.") + + return data + + except Exception as e: + return ProcessException( + source=ExceptionSource.remove_url, + message=f"Error in removing URL: {e}\n{traceback.format_exc()}", + ) + class ASRLiveService(ASRService): """ASR Service module for live endpoints.""" diff --git a/src/wordcab_transcribe/services/concurrency_services.py b/src/wordcab_transcribe/services/concurrency_services.py index b604db0..0e1d133 100644 --- a/src/wordcab_transcribe/services/concurrency_services.py +++ b/src/wordcab_transcribe/services/concurrency_services.py @@ -77,14 +77,31 @@ def __init__(self, remote_urls: List[str]) -> None: remote_urls (List[str]): List of remote URLs to use. """ self.remote_urls: List[str] = remote_urls + self._init_queue() - # If there is only one URL, we don't need to use a queue - if len(self.remote_urls) == 1: - self.queue = None - else: - self.queue = asyncio.Queue(maxsize=len(self.remote_urls)) - for url in self.remote_urls: - self.queue.put_nowait(url) + def _init_queue(self) -> None: + """Initialize the queue with the available URLs.""" + self.queue = asyncio.Queue(maxsize=len(self.remote_urls)) + for url in self.remote_urls: + self.queue.put_nowait(url) + + def get_queue_size(self) -> int: + """ + Get the current queue size. + + Returns: + int: Current queue size. + """ + return self.queue.qsize() + + def get_urls(self) -> List[str]: + """ + Get the list of available URLs. + + Returns: + List[str]: List of available URLs. + """ + return self.remote_urls async def next_url(self) -> str: """ @@ -93,13 +110,35 @@ async def next_url(self) -> str: Returns: str: Next available URL. """ - if self.queue is None: - return self.remote_urls[0] + url = self.queue.get_nowait() + # Unlike GPU we don't want to block remote ASR requests. + # So we re-insert the URL back into the queue after getting it. + self.queue.put_nowait(url) - else: - url = self.queue.get_nowait() - # Unlike GPU we don't want to block remote ASR requests. - # So we re-insert the URL back into the queue after getting it. - self.queue.put_nowait(url) + return url + + async def add_url(self, url: str) -> None: + """ + Add a URL to the pool of available URLs. + + Args: + url (str): URL to add to the queue. + """ + if url not in self.remote_urls: + self.remote_urls.append(url) + + # Re-initialize the queue with the new URL. + self._init_queue() + + async def remove_url(self, url: str) -> None: + """ + Remove a URL from the pool of available URLs. + + Args: + url (str): URL to remove from the queue. + """ + if url in self.remote_urls: + self.remote_urls.remove(url) - return url + # Re-initialize the queue without the removed URL. + self._init_queue()