diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py index 4d0248b..86a3e93 100644 --- a/src/wordcab_transcribe/models.py +++ b/src/wordcab_transcribe/models.py @@ -484,13 +484,6 @@ class Config: } -class UrlRequest(BaseModel): - """Request model for the add_url endpoint.""" - - task: Literal["transcription", "diarization"] - url: str - - class UrlSchema(BaseModel): """Request model for the add_url endpoint.""" diff --git a/src/wordcab_transcribe/router/v1/endpoints.py b/src/wordcab_transcribe/router/v1/endpoints.py index 48f7b5b..fad4397 100644 --- a/src/wordcab_transcribe/router/v1/endpoints.py +++ b/src/wordcab_transcribe/router/v1/endpoints.py @@ -48,7 +48,7 @@ diarize_routers = (diarize_router, "/diarize", "diarization") manage_remote_url_routers = ( manage_remote_url_router, - "/remote-url", + "/url", "remote-url", ) diff --git a/src/wordcab_transcribe/router/v1/manage_remote_url.py b/src/wordcab_transcribe/router/v1/manage_remote_url.py index 49c26e9..536f258 100644 --- a/src/wordcab_transcribe/router/v1/manage_remote_url.py +++ b/src/wordcab_transcribe/router/v1/manage_remote_url.py @@ -19,36 +19,54 @@ # and limitations under the License. """Add Remote URL endpoint for remote transcription or diarization.""" -from typing import Union +from typing import List, Union from fastapi import APIRouter, HTTPException from fastapi import status as http_status from loguru import logger -from pydantic import ValidationError +from pydantic import HttpUrl +from typing_extensions import Literal from wordcab_transcribe.dependencies import asr -from wordcab_transcribe.models import UrlRequest, UrlSchema -from wordcab_transcribe.services.asr_service import ProcessException +from wordcab_transcribe.models import UrlSchema +from wordcab_transcribe.services.asr_service import ExceptionSource, ProcessException router = APIRouter() +@router.get( + "", + response_model=Union[List[HttpUrl], str], + status_code=http_status.HTTP_200_OK, +) +async def get_url(task: Literal["transcription", "diarization"]) -> List[HttpUrl]: + """Get Remote URL endpoint for remote transcription or diarization.""" + result: List[UrlSchema] = await asr.get_url(task) + + if isinstance(result, ProcessException): + logger.error(result.message) + if result.source == ExceptionSource.get_url: + raise HTTPException( + status_code=http_status.HTTP_405_METHOD_NOT_ALLOWED, + detail=str(result.message), + ) + else: + raise HTTPException( + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=str(result.message), + ) + + return result + + @router.post( "/add", response_model=Union[UrlSchema, str], status_code=http_status.HTTP_200_OK, ) -async def add_url(data: UrlRequest) -> UrlSchema: +async def add_url(data: UrlSchema) -> UrlSchema: """Add Remote URL endpoint for remote transcription or diarization.""" - try: - _data = UrlSchema(task=data.task, url=data.url) - except ValidationError as e: - raise HTTPException( # noqa: B904 - status_code=http_status.HTTP_400_BAD_REQUEST, - detail=f"Your request is invalid: {e}", - ) - - result: UrlSchema = await asr.add_url(_data) + result: UrlSchema = await asr.add_url(data) if isinstance(result, ProcessException): logger.error(result.message) @@ -65,17 +83,9 @@ async def add_url(data: UrlRequest) -> UrlSchema: response_model=Union[UrlSchema, str], status_code=http_status.HTTP_200_OK, ) -async def remove_url(data: UrlRequest) -> UrlSchema: +async def remove_url(data: UrlSchema) -> UrlSchema: """Remove Remote URL endpoint for remote transcription or diarization.""" - try: - _data = UrlSchema(task=data.task, url=data.url) - except ValidationError as e: - raise HTTPException( # noqa: B904 - status_code=http_status.HTTP_400_BAD_REQUEST, - detail=f"Your request is invalid: {e}", - ) - - result: UrlSchema = await asr.remove_url(_data) + result: UrlSchema = await asr.remove_url(data) if isinstance(result, ProcessException): logger.error(result.message) diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index c23ecb8..a25c047 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -33,6 +33,7 @@ from loguru import logger from pydantic import BaseModel, ConfigDict from tensorshare import Backend, TensorShare +from typing_extensions import Literal from wordcab_transcribe.logging import time_and_tell, time_and_tell_async from wordcab_transcribe.models import ( @@ -58,6 +59,7 @@ class ExceptionSource(str, Enum): add_url = "add_url" diarization = "diarization" + get_url = "get_url" post_processing = "post_processing" remove_url = "remove_url" transcription = "transcription" @@ -720,26 +722,66 @@ async def remote_diarization( else: return DiarizationOutput(**await response.json()) - async def add_url(self, data: UrlSchema) -> UrlSchema: + async def get_url( + self, task: Literal["transcription", "diarization"] + ) -> Union[List[str], ProcessException]: + """Get the list of remote URLs.""" + try: + if task == "transcription": + # Case 1: We are not using remote transcription + if self.use_remote_transcription is False: + return ProcessException( + source=ExceptionSource.get_url, + message="You are not using remote transcription.", + ) + # Case 2: We are using remote transcription + else: + return self.transcription_url_handler.get_urls() + + elif task == "diarization": + # Case 1: We are not using remote diarization + if self.use_remote_diarization is False: + return ProcessException( + source=ExceptionSource.get_url, + message="You are not using remote diarization.", + ) + # Case 2: We are using remote diarization + else: + return self.diarization_url_handler.get_urls() + + else: + raise ValueError(f"{task} is not a valid task.") + + except Exception as e: + return ProcessException( + source=ExceptionSource.get_url, + message=f"Error in getting URL: {e}\n{traceback.format_exc()}", + ) + + async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]: """Add a remote URL to the list of URLs.""" try: if data.task == "transcription": # Case 1: We are not using remote transcription yet if self.use_remote_transcription is False: self.use_remote_transcription = True - self.transcription_url_handler = URLService(remote_urls=[data.url]) + self.transcription_url_handler = URLService( + remote_urls=[str(data.url)] + ) # Case 2: We are already using remote transcription else: - self.transcription_url_handler.add_url(data.url) + await self.transcription_url_handler.add_url(data.url) elif data.task == "diarization": # Case 1: We are not using remote diarization yet if self.use_remote_diarization is False: self.use_remote_diarization = True - self.diarization_url_handler = URLService(remote_urls=[data.url]) + self.diarization_url_handler = URLService( + remote_urls=[str(data.url)] + ) # Case 2: We are already using remote diarization else: - self.diarization_url_handler.add_url(data.url) + await self.diarization_url_handler.add_url(data.url) else: raise ValueError(f"{data.task} is not a valid task.") @@ -761,7 +803,7 @@ async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException raise ValueError("You are not using remote transcription.") # Case 2: We are using remote transcription else: - self.transcription_url_handler.remove_url(data.url) + await self.transcription_url_handler.remove_url(str(data.url)) if self.transcription_url_handler.get_queue_size() == 0: # TODO: Add a way to switch back to local transcription pass @@ -772,7 +814,7 @@ async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException raise ValueError("You are not using remote diarization.") # Case 2: We are using remote diarization else: - self.diarization_url_handler.remove_url(data.url) + await self.diarization_url_handler.remove_url(str(data.url)) if self.diarization_url_handler.get_queue_size() == 0: # TODO: Add a way to switch back to local diarization pass diff --git a/src/wordcab_transcribe/services/concurrency_services.py b/src/wordcab_transcribe/services/concurrency_services.py index 516d767..0e1d133 100644 --- a/src/wordcab_transcribe/services/concurrency_services.py +++ b/src/wordcab_transcribe/services/concurrency_services.py @@ -94,6 +94,15 @@ def get_queue_size(self) -> int: """ return self.queue.qsize() + def get_urls(self) -> List[str]: + """ + Get the list of available URLs. + + Returns: + List[str]: List of available URLs. + """ + return self.remote_urls + async def next_url(self) -> str: """ We use this to iterate equally over the available URLs.