Merge pull request #287 from Wordcab/284-have-option-to-switch-to-distil-whisper

284 have option to switch to distil whisper
aleksandr-smechov authored Mar 21, 2024
2 parents 8c967bd + 2c58ea1 commit fa83363
Showing 15 changed files with 1,295 additions and 135 deletions.
7 changes: 7 additions & 0 deletions .env
@@ -102,6 +102,13 @@ SVIX_API_KEY=
# The svix_app_id parameter is used in the cortex implementation to enable webhooks.
SVIX_APP_ID=
#
# ----------------------------------------------- AWS CONFIGURATION ------------------------------------------------- #
#
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
AWS_STORAGE_BUCKET_NAME=
AWS_S3_REGION_NAME=
#
# -------------------------------------------------- REMOTE SERVERS -------------------------------------------------- #
# The remote servers configuration is used to control the number of servers used to process the requests if you don't
# want to group all the services in one server.
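A minimal sketch of reading the new AWS variables outside the app, assuming python-dotenv (already a project dependency) and the variable names from the template above; the service itself loads them through its Settings class, as the config.py hunks below show.

import os

from dotenv import load_dotenv

# Populate os.environ from the .env file in the working directory.
load_dotenv()

aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID", "")
aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY", "")
aws_storage_bucket_name = os.getenv("AWS_STORAGE_BUCKET_NAME", "")
aws_s3_region_name = os.getenv("AWS_S3_REGION_NAME", "")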
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -29,13 +29,17 @@ classifiers = [
dependencies = [
"aiohttp>=3.8.4",
"aiofiles>=23.1.0",
"boto3",
"ctranslate2>=3.18.0",
"faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master",
"ffmpeg-python>=0.2.0",
"transformers@git+https://github.com/huggingface/transformers.git@assistant_decoding_batch",
"librosa>=0.9.0",
"loguru>=0.6.0",
"nltk>=3.8.1",
"numpy==1.23.1",
"onnxruntime>=1.15.0",
"pandas>=2.1.2",
"pydantic>=1.10.9",
"python-dotenv>=1.0.0",
"tensorshare>=0.1.1",
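A quick, illustrative check (not part of the repo) that the newly added dependencies resolve after installation; importlib.metadata is standard library, and the distribution names match the entries above, including the VCS-pinned builds.

from importlib.metadata import version

for pkg in ("boto3", "pandas", "faster-whisper", "transformers"):
    # Prints the installed version of each newly added or updated dependency.
    print(pkg, version(pkg))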
10 changes: 10 additions & 0 deletions src/wordcab_transcribe/config.py
@@ -63,6 +63,11 @@ class Settings:
access_token_expire_minutes: int
# Cortex configuration
cortex_api_key: str
# AWS configuration
aws_access_key_id: str
aws_secret_access_key: str
aws_storage_bucket_name: str
aws_region_name: str
# Svix configuration
svix_api_key: str
svix_app_id: str
@@ -266,6 +271,11 @@ def __post_init__(self):
access_token_expire_minutes=getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30),
# Cortex configuration
cortex_api_key=getenv("WORDCAB_TRANSCRIBE_API_KEY", ""),
# AWS configuration
aws_access_key_id=getenv("AWS_ACCESS_KEY_ID", ""),
aws_secret_access_key=getenv("AWS_SECRET_ACCESS_KEY", ""),
aws_storage_bucket_name=getenv("AWS_STORAGE_BUCKET_NAME", ""),
aws_region_name=getenv("AWS_REGION_NAME", ""),
# Svix configuration
svix_api_key=getenv("SVIX_API_KEY", ""),
svix_app_id=getenv("SVIX_APP_ID", ""),
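The new Settings fields map one-to-one onto boto3 client arguments; below is a minimal sketch that mirrors the retrieve_service() helper added in audio_url_endpoint.py further down. Note that the .env template above adds AWS_S3_REGION_NAME while __post_init__ appears to read AWS_REGION_NAME, so the region needs to be exported under the name Settings expects.

import boto3

from wordcab_transcribe.config import settings

# Build an S3 client from the new AWS settings fields (illustrative).
s3_client = boto3.client(
    "s3",
    aws_access_key_id=settings.aws_access_key_id,
    aws_secret_access_key=settings.aws_secret_access_key,
    region_name=settings.aws_region_name,
)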
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/main.py
@@ -46,7 +46,7 @@
# Add logging middleware
app.add_middleware(LoggingMiddleware, debug_mode=settings.debug)

# Include the appropiate routers based on the settings
# Include the appropriate routers based on the settings
if settings.debug is False:
app.include_router(auth_router, tags=["authentication"])
app.include_router(
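For context, a self-contained sketch of the router-registration pattern this hunk touches: the authentication router is only mounted when debug mode is off. Names mirror the diff; the real app wires settings and routers from the package.

from fastapi import APIRouter, FastAPI

app = FastAPI()
auth_router = APIRouter()
debug = False  # stands in for settings.debug

# Mirror main.py: mount the auth router outside debug mode only.
if debug is False:
    app.include_router(auth_router, tags=["authentication"])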
8 changes: 7 additions & 1 deletion src/wordcab_transcribe/models.py
@@ -93,6 +93,8 @@ class BaseResponse(BaseModel):
no_speech_threshold: float
condition_on_previous_text: bool
process_times: ProcessTimes
job_name: Optional[str] = None
task_token: Optional[str] = None


class AudioResponse(BaseResponse):
@@ -240,6 +242,7 @@ class CortexPayload(BaseModel):
no_speech_threshold: Optional[float] = 0.6
condition_on_previous_text: Optional[bool] = True
job_name: Optional[str] = None
task_token: Optional[str] = None
ping: Optional[bool] = False

class Config:
@@ -406,6 +409,8 @@ class BaseRequest(BaseModel):
log_prob_threshold: float = -1.0
no_speech_threshold: float = 0.6
condition_on_previous_text: bool = True
job_name: Optional[str] = None
task_token: Optional[str] = None

@field_validator("vocab")
def validate_each_vocab_value(
@@ -518,7 +523,8 @@ class DiarizationOutput(BaseModel):
class DiarizationRequest(BaseModel):
"""Request model for the diarize endpoint."""

audio: TensorShare
audio: Union[TensorShare, str]
audio_type: Optional[str]
duration: float
num_speakers: int

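The DiarizationRequest change lets callers pass either an in-memory TensorShare or a plain audio URL, with audio_type describing the string. An illustrative stand-alone model with the same shape (pydantic v2, a dict standing in for TensorShare; the "url" value for audio_type is an assumption):

from typing import Optional, Union

from pydantic import BaseModel


class DiarizationRequestSketch(BaseModel):
    audio: Union[dict, str]  # dict stands in for TensorShare here
    audio_type: Optional[str]  # e.g. "url" when audio is a string (assumed)
    duration: float
    num_speakers: int


req = DiarizationRequestSketch(
    audio="https://example.com/meeting.wav",
    audio_type="url",
    duration=372.5,
    num_speakers=2,
)
print(req.model_dump())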
236 changes: 169 additions & 67 deletions src/wordcab_transcribe/router/v1/audio_url_endpoint.py
@@ -20,16 +20,23 @@
"""Audio url endpoint for the Wordcab Transcribe API."""

import asyncio
import json
from datetime import datetime
from typing import List, Optional, Union

import boto3
import shortuuid
from fastapi import APIRouter, BackgroundTasks, HTTPException
from fastapi import status as http_status
from loguru import logger
from svix.api import MessageIn, SvixAsync

from wordcab_transcribe.config import settings
from wordcab_transcribe.dependencies import asr, download_limit
from wordcab_transcribe.models import AudioRequest, AudioResponse
from wordcab_transcribe.services.asr_service import ProcessException
from wordcab_transcribe.models import (
AudioRequest,
AudioResponse,
)
from wordcab_transcribe.utils import (
check_num_channels,
delete_file,
@@ -40,86 +47,181 @@
router = APIRouter()


@router.post("", response_model=AudioResponse, status_code=http_status.HTTP_200_OK)
def retrieve_service(service, aws_creds):
return boto3.client(
service,
aws_access_key_id=aws_creds.get("aws_access_key_id"),
aws_secret_access_key=aws_creds.get("aws_secret_access_key"),
region_name=aws_creds.get("region_name"),
)


s3_client = retrieve_service(
"s3",
{
"aws_access_key_id": settings.aws_access_key_id,
"aws_secret_access_key": settings.aws_secret_access_key,
"region_name": settings.aws_region_name,
},
)


@router.post("", status_code=http_status.HTTP_202_ACCEPTED)
async def inference_with_audio_url(
background_tasks: BackgroundTasks,
url: str,
data: Optional[AudioRequest] = None,
) -> AudioResponse:
) -> dict:
"""Inference endpoint with audio url."""
filename = f"audio_url_{shortuuid.ShortUUID().random(length=32)}"

data = AudioRequest() if data is None else AudioRequest(**data.dict())

async with download_limit:
_filepath = await download_audio_file("url", url, filename)
async def process_audio():
try:
async with download_limit:
_filepath = await download_audio_file("url", url, filename)

num_channels = await check_num_channels(_filepath)
if num_channels > 1 and data.multi_channel is False:
num_channels = 1 # Force mono channel if more than 1 channel
num_channels = await check_num_channels(_filepath)
if num_channels > 1 and data.multi_channel is False:
num_channels = 1 # Force mono channel if more than 1 channel

try:
filepath: Union[str, List[str]] = await process_audio_file(
_filepath, num_channels=num_channels
)
try:
filepath: Union[str, List[str]] = await process_audio_file(
_filepath, num_channels=num_channels
)

except Exception as e:
raise HTTPException( # noqa: B904
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Process failed: {e}",
)

background_tasks.add_task(delete_file, filepath=filename)

task = asyncio.create_task(
asr.process_input(
filepath=filepath,
url=url,
url_type="url",
offset_start=data.offset_start,
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps_format=data.timestamps,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
internal_vad=data.internal_vad,
repetition_penalty=data.repetition_penalty,
compression_ratio_threshold=data.compression_ratio_threshold,
log_prob_threshold=data.log_prob_threshold,
no_speech_threshold=data.no_speech_threshold,
condition_on_previous_text=data.condition_on_previous_text,
)
)

result = await task
utterances, process_times, audio_duration = result
result = AudioResponse(
utterances=utterances,
audio_duration=audio_duration,
offset_start=data.offset_start,
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps=data.timestamps,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
internal_vad=data.internal_vad,
repetition_penalty=data.repetition_penalty,
compression_ratio_threshold=data.compression_ratio_threshold,
log_prob_threshold=data.log_prob_threshold,
no_speech_threshold=data.no_speech_threshold,
condition_on_previous_text=data.condition_on_previous_text,
job_name=data.job_name,
task_token=data.task_token,
process_times=process_times,
)

upload_file(
s3_client,
file=bytes(json.dumps(result.model_dump()).encode("UTF-8")),
bucket=settings.aws_storage_bucket_name,
object_name=f"responses/{data.task_token}_{data.job_name}.json",
)

background_tasks.add_task(delete_file, filepath=filepath)
await send_update_with_svix(
data.job_name,
"finished",
{
"job_name": data.job_name,
"task_token": data.task_token,
},
)
except Exception as e:
raise HTTPException( # noqa: B904
status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Process failed: {e}",
)

background_tasks.add_task(delete_file, filepath=filename)

task = asyncio.create_task(
asr.process_input(
filepath=filepath,
offset_start=data.offset_start,
offset_end=data.offset_end,
num_speakers=data.num_speakers,
diarization=data.diarization,
multi_channel=data.multi_channel,
source_lang=data.source_lang,
timestamps_format=data.timestamps,
vocab=data.vocab,
word_timestamps=data.word_timestamps,
internal_vad=data.internal_vad,
repetition_penalty=data.repetition_penalty,
compression_ratio_threshold=data.compression_ratio_threshold,
log_prob_threshold=data.log_prob_threshold,
no_speech_threshold=data.no_speech_threshold,
condition_on_previous_text=data.condition_on_previous_text,
)
error_message = f"Error during transcription: {e}"
logger.error(error_message)

error_payload = {
"error": error_message,
"job_name": data.job_name,
"task_token": data.task_token,
}

await send_update_with_svix(data.job_name, "error", error_payload)

# Add the process_audio function to background tasks
background_tasks.add_task(process_audio)

# Return the job name and task token immediately
return {"job_name": data.job_name, "task_token": data.task_token}


    result = await task

    background_tasks.add_task(delete_file, filepath=filepath)

    if isinstance(result, ProcessException):
        logger.error(result.message)
        raise HTTPException(
            status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=str(result.message),
        )
    else:
        utterances, process_times, audio_duration = result
        return AudioResponse(
            utterances=utterances,
            audio_duration=audio_duration,
            offset_start=data.offset_start,
            offset_end=data.offset_end,
            num_speakers=data.num_speakers,
            diarization=data.diarization,
            multi_channel=data.multi_channel,
            source_lang=data.source_lang,
            timestamps=data.timestamps,
            vocab=data.vocab,
            word_timestamps=data.word_timestamps,
            internal_vad=data.internal_vad,
            repetition_penalty=data.repetition_penalty,
            compression_ratio_threshold=data.compression_ratio_threshold,
            log_prob_threshold=data.log_prob_threshold,
            no_speech_threshold=data.no_speech_threshold,
            condition_on_previous_text=data.condition_on_previous_text,
            process_times=process_times,
        )


def upload_file(s3_client, file, bucket, object_name):
    try:
        s3_client.put_object(
            Body=file,
            Bucket=bucket,
            Key=object_name,
        )
    except Exception as e:
        logger.error(f"Exception while uploading results to S3: {e}")
        return False
    return True


async def send_update_with_svix(
    job_name: str,
    status: str,
    payload: dict,
    payload_retention_period: Optional[int] = 5,
) -> None:
    """
    Send the status update to Svix.

    Args:
        job_name (str): The name of the job.
        status (str): The status of the job.
        payload (dict): The payload to send.
        payload_retention_period (Optional[int], optional): The payload retention period. Defaults to 5.
    """
    if settings.svix_api_key and settings.svix_app_id:
        svix = SvixAsync(settings.svix_api_key)
        await svix.message.create(
            settings.svix_app_id,
            MessageIn(
                event_type=f"async_job.wordcab_transcribe.{status}",
                event_id=f"wordcab_transcribe_{status}_{job_name}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')}",
                payload_retention_period=payload_retention_period,
                payload=payload,
            ),
        )
    else:
        logger.warning(
            "Svix API key and app ID are not set. Cannot send the status update to"
            " Svix."
        )
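Taken together, the endpoint now answers immediately with 202 Accepted and does the heavy lifting in a background task, publishing the final AudioResponse to S3 under responses/{task_token}_{job_name}.json and firing Svix "finished"/"error" events. A client-side sketch of that flow; the service host, route path, bucket name, and use of requests are assumptions, while the query/body split and the S3 key format come from the code above.

import json

import boto3
import requests

# Submit the job; `url` is a query parameter, job metadata goes in the JSON body.
resp = requests.post(
    "http://localhost:5001/api/v1/audio-url",  # assumed host and route
    params={"url": "https://example.com/meeting.wav"},
    json={"job_name": "meeting-42", "task_token": "abc123"},
)
resp.raise_for_status()
job = resp.json()  # HTTP 202 -> {"job_name": "meeting-42", "task_token": "abc123"}

# Later (for example, on receiving the Svix "finished" event), fetch the stored result.
s3 = boto3.client("s3")  # assumed: credentials and bucket configured out of band
obj = s3.get_object(
    Bucket="my-transcribe-bucket",  # assumed bucket name
    Key=f"responses/{job['task_token']}_{job['job_name']}.json",
)
result = json.loads(obj["Body"].read())
print(result["audio_duration"], len(result["utterances"]))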
4 changes: 2 additions & 2 deletions src/wordcab_transcribe/router/v1/cortex_endpoint.py
@@ -50,7 +50,7 @@
response_model=Union[
CortexError, CortexUrlResponse, CortexYoutubeResponse, PongResponse
],
status_code=http_status.HTTP_200_OK,
status_code=http_status.HTTP_202_ACCEPTED,
)
async def run_cortex(
payload: CortexPayload, request: Request
@@ -137,7 +137,7 @@ async def run_cortex(
return CortexError(message=error_message)

_cortex_response = {
**response.model_dump(),
**response,
"job_name": payload.job_name,
"request_id": request_id,
}
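The switch from **response.model_dump() to **response implies the downstream call now hands back a plain dict (such as the job_name/task_token payload returned by the reworked audio-url endpoint) rather than a pydantic model. A small sketch of the resulting merge, with the dict shape assumed from the endpoint above:

response = {"job_name": "meeting-42", "task_token": "abc123"}  # assumed shape
payload_job_name = "meeting-42"
request_id = "req_0001"

_cortex_response = {
    **response,
    "job_name": payload_job_name,  # later keys override the splatted dict
    "request_id": request_id,
}
print(_cortex_response)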