Skip to content

Commit

Permalink
chore: handle "whisper-1" model name
Browse files Browse the repository at this point in the history
  • Loading branch information
Fedir Zadniprovskyi committed Jun 12, 2024
1 parent 40cedac commit a83bc4f
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions faster_whisper_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
FastAPI,
Form,
HTTPException,
Path,
Query,
Response,
UploadFile,
Expand All @@ -22,6 +23,7 @@
from faster_whisper import WhisperModel
from faster_whisper.vad import VadOptions, get_speech_timestamps
from huggingface_hub.hf_api import ModelInfo
from pydantic import AfterValidator

from faster_whisper_server import utils
from faster_whisper_server.asr import FasterWhisperASR
Expand Down Expand Up @@ -85,7 +87,7 @@ def health() -> Response:
return Response(status_code=200, content="OK")


@app.get("/v1/models", response_model=list[ModelObject])
@app.get("/v1/models")
def get_models() -> list[ModelObject]:
models = huggingface_hub.list_models(library="ctranslate2")
models = [
Expand All @@ -101,8 +103,8 @@ def get_models() -> list[ModelObject]:
return models


@app.get("/v1/models/{model_name:path}", response_model=ModelObject)
def get_model(model_name: str) -> ModelObject:
@app.get("/v1/models/{model_name:path}")
def get_model(model_name: Annotated[str, Path()]) -> ModelObject:
models = list(
huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
)
Expand Down Expand Up @@ -131,10 +133,25 @@ def format_as_sse(data: str) -> str:
return f"data: {data}\n\n"


def handle_default_openai_model(model_name: str) -> str:
"""This exists because some callers may not be able override the default("whisper-1") model name.
For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
"""
if model_name == "whisper-1":
logger.info(
f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
)
return config.whisper.model
return model_name


ModelName = Annotated[str, AfterValidator(handle_default_openai_model)]


@app.post("/v1/audio/translations")
def translate_file(
file: Annotated[UploadFile, Form()],
model: Annotated[str, Form()] = config.whisper.model,
model: Annotated[ModelName, Form()] = config.whisper.model,
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
temperature: Annotated[float, Form()] = 0.0,
Expand Down Expand Up @@ -187,7 +204,7 @@ def segment_responses() -> Generator[str, None, None]:
@app.post("/v1/audio/transcriptions")
def transcribe_file(
file: Annotated[UploadFile, Form()],
model: Annotated[str, Form()] = config.whisper.model,
model: Annotated[ModelName, Form()] = config.whisper.model,
language: Annotated[Language | None, Form()] = config.default_language,
prompt: Annotated[str | None, Form()] = None,
response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
Expand Down Expand Up @@ -289,7 +306,7 @@ async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
@app.websocket("/v1/audio/transcriptions")
async def transcribe_stream(
ws: WebSocket,
model: Annotated[str, Query()] = config.whisper.model,
model: Annotated[ModelName, Query()] = config.whisper.model,
language: Annotated[Language | None, Query()] = config.default_language,
response_format: Annotated[
ResponseFormat, Query()
Expand Down

0 comments on commit a83bc4f

Please sign in to comment.