Skip to content

Commit

Permalink
add get_url + update endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
chainyo committed Oct 4, 2023
1 parent e7fcb65 commit 5c62df3
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 39 deletions.
7 changes: 0 additions & 7 deletions src/wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,13 +484,6 @@ class Config:
}


class UrlRequest(BaseModel):
"""Request model for the add_url and remove_url endpoints."""

# Task the remote URL is registered for.
task: Literal["transcription", "diarization"]
# Remote URL, carried as a plain string in this model.
url: str


class UrlSchema(BaseModel):
"""Request model for the add_url endpoint."""

Expand Down
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/router/v1/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
diarize_routers = (diarize_router, "/diarize", "diarization")
manage_remote_url_routers = (
manage_remote_url_router,
"/remote-url",
"/url",
"remote-url",
)

Expand Down
58 changes: 34 additions & 24 deletions src/wordcab_transcribe/router/v1/manage_remote_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,36 +19,54 @@
# and limitations under the License.
"""Add Remote URL endpoint for remote transcription or diarization."""

from typing import Union
from typing import List, Union

from fastapi import APIRouter, HTTPException
from fastapi import status as http_status
from loguru import logger
from pydantic import ValidationError
from pydantic import HttpUrl
from typing_extensions import Literal

from wordcab_transcribe.dependencies import asr
from wordcab_transcribe.models import UrlRequest, UrlSchema
from wordcab_transcribe.services.asr_service import ProcessException
from wordcab_transcribe.models import UrlSchema
from wordcab_transcribe.services.asr_service import ExceptionSource, ProcessException

router = APIRouter()


@router.get(
    "",
    response_model=Union[List[HttpUrl], str],
    status_code=http_status.HTTP_200_OK,
)
async def get_url(task: Literal["transcription", "diarization"]) -> List[HttpUrl]:
    """
    Get Remote URL endpoint for remote transcription or diarization.

    Args:
        task: Task whose remote URLs should be listed.

    Returns:
        The list of URLs returned by the ASR service for the requested task.

    Raises:
        HTTPException: 405 when the ASR service reports an error whose source
            is `get_url`, 500 for an error coming from any other source.
    """
    # The service returns either the URL list or a ProcessException describing
    # the failure; it does not raise.
    result: Union[List[str], ProcessException] = await asr.get_url(task)

    if isinstance(result, ProcessException):
        logger.error(result.message)
        status_code = (
            http_status.HTTP_405_METHOD_NOT_ALLOWED
            if result.source == ExceptionSource.get_url
            else http_status.HTTP_500_INTERNAL_SERVER_ERROR
        )
        raise HTTPException(status_code=status_code, detail=str(result.message))

    return result


@router.post(
"/add",
response_model=Union[UrlSchema, str],
status_code=http_status.HTTP_200_OK,
)
async def add_url(data: UrlRequest) -> UrlSchema:
async def add_url(data: UrlSchema) -> UrlSchema:
"""Add Remote URL endpoint for remote transcription or diarization."""
try:
_data = UrlSchema(task=data.task, url=data.url)
except ValidationError as e:
raise HTTPException( # noqa: B904
status_code=http_status.HTTP_400_BAD_REQUEST,
detail=f"Your request is invalid: {e}",
)

result: UrlSchema = await asr.add_url(_data)
result: UrlSchema = await asr.add_url(data)

if isinstance(result, ProcessException):
logger.error(result.message)
Expand All @@ -65,17 +83,9 @@ async def add_url(data: UrlRequest) -> UrlSchema:
response_model=Union[UrlSchema, str],
status_code=http_status.HTTP_200_OK,
)
async def remove_url(data: UrlRequest) -> UrlSchema:
async def remove_url(data: UrlSchema) -> UrlSchema:
"""Remove Remote URL endpoint for remote transcription or diarization."""
try:
_data = UrlSchema(task=data.task, url=data.url)
except ValidationError as e:
raise HTTPException( # noqa: B904
status_code=http_status.HTTP_400_BAD_REQUEST,
detail=f"Your request is invalid: {e}",
)

result: UrlSchema = await asr.remove_url(_data)
result: UrlSchema = await asr.remove_url(data)

if isinstance(result, ProcessException):
logger.error(result.message)
Expand Down
56 changes: 49 additions & 7 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from loguru import logger
from pydantic import BaseModel, ConfigDict
from tensorshare import Backend, TensorShare
from typing_extensions import Literal

from wordcab_transcribe.logging import time_and_tell, time_and_tell_async
from wordcab_transcribe.models import (
Expand All @@ -58,6 +59,7 @@ class ExceptionSource(str, Enum):

add_url = "add_url"
diarization = "diarization"
get_url = "get_url"
post_processing = "post_processing"
remove_url = "remove_url"
transcription = "transcription"
Expand Down Expand Up @@ -720,26 +722,66 @@ async def remote_diarization(
else:
return DiarizationOutput(**await response.json())

async def add_url(self, data: UrlSchema) -> UrlSchema:
async def get_url(
    self, task: Literal["transcription", "diarization"]
) -> Union[List[str], ProcessException]:
    """
    Return the remote URLs registered for the given task.

    Args:
        task: Either "transcription" or "diarization".

    Returns:
        The result of the matching URL handler's `get_urls()` when the remote
        mode for that task is enabled. Otherwise a `ProcessException` with
        source `get_url`: either because the remote mode flag is False, or
        because an exception (including an invalid task) occurred.
    """
    try:
        if task == "transcription":
            # Remote transcription enabled: delegate to its URL handler.
            if self.use_remote_transcription is not False:
                return self.transcription_url_handler.get_urls()
            return ProcessException(
                source=ExceptionSource.get_url,
                message="You are not using remote transcription.",
            )

        if task == "diarization":
            # Remote diarization enabled: delegate to its URL handler.
            if self.use_remote_diarization is not False:
                return self.diarization_url_handler.get_urls()
            return ProcessException(
                source=ExceptionSource.get_url,
                message="You are not using remote diarization.",
            )

        # Raised inside the try so it is reported through the handler below.
        raise ValueError(f"{task} is not a valid task.")

    except Exception as err:
        return ProcessException(
            source=ExceptionSource.get_url,
            message=f"Error in getting URL: {err}\n{traceback.format_exc()}",
        )

async def add_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException]:
"""Add a remote URL to the list of URLs."""
try:
if data.task == "transcription":
# Case 1: We are not using remote transcription yet
if self.use_remote_transcription is False:
self.use_remote_transcription = True
self.transcription_url_handler = URLService(remote_urls=[data.url])
self.transcription_url_handler = URLService(
remote_urls=[str(data.url)]
)
# Case 2: We are already using remote transcription
else:
self.transcription_url_handler.add_url(data.url)
await self.transcription_url_handler.add_url(data.url)

elif data.task == "diarization":
# Case 1: We are not using remote diarization yet
if self.use_remote_diarization is False:
self.use_remote_diarization = True
self.diarization_url_handler = URLService(remote_urls=[data.url])
self.diarization_url_handler = URLService(
remote_urls=[str(data.url)]
)
# Case 2: We are already using remote diarization
else:
self.diarization_url_handler.add_url(data.url)
await self.diarization_url_handler.add_url(data.url)

else:
raise ValueError(f"{data.task} is not a valid task.")
Expand All @@ -761,7 +803,7 @@ async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException
raise ValueError("You are not using remote transcription.")
# Case 2: We are using remote transcription
else:
self.transcription_url_handler.remove_url(data.url)
await self.transcription_url_handler.remove_url(str(data.url))
if self.transcription_url_handler.get_queue_size() == 0:
# TODO: Add a way to switch back to local transcription
pass
Expand All @@ -772,7 +814,7 @@ async def remove_url(self, data: UrlSchema) -> Union[UrlSchema, ProcessException
raise ValueError("You are not using remote diarization.")
# Case 2: We are using remote diarization
else:
self.diarization_url_handler.remove_url(data.url)
await self.diarization_url_handler.remove_url(str(data.url))
if self.diarization_url_handler.get_queue_size() == 0:
# TODO: Add a way to switch back to local diarization
pass
Expand Down
9 changes: 9 additions & 0 deletions src/wordcab_transcribe/services/concurrency_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ def get_queue_size(self) -> int:
"""
return self.queue.qsize()

def get_urls(self) -> List[str]:
    """
    Get the list of available URLs.

    Returns:
        List[str]: A shallow copy of the list of available URLs.
    """
    # Hand out a copy so callers cannot mutate the service's own list.
    return list(self.remote_urls)

async def next_url(self) -> str:
"""
We use this to iterate equally over the available URLs.
Expand Down

0 comments on commit 5c62df3

Please sign in to comment.