Skip to content

Commit

Permalink
Allow local or remote ASR process (#258)
Browse files Browse the repository at this point in the history
* update config

* add local or remote for asr

* fix some schemas

* add missing params to ASRAsync init

* fix the endpoints and asr process

* fix tests

* fix tests

* fix new routers
  • Loading branch information
Thomas Chaigneau authored Sep 29, 2023
1 parent 5c3c148 commit 74fd976
Show file tree
Hide file tree
Showing 19 changed files with 606 additions and 244 deletions.
19 changes: 18 additions & 1 deletion .env
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# The name of the project, used for API documentation.
PROJECT_NAME="Wordcab Transcribe"
# The version of the project, used for API documentation.
VERSION="0.5.0"
VERSION="0.5.1"
# The description of the project, used for API documentation.
DESCRIPTION="💬 ASR FastAPI server using faster-whisper and Auto-Tuning Spectral Clustering for diarization."
# This API prefix is used for all endpoints in the API outside of the status and cortex endpoints.
Expand Down Expand Up @@ -60,6 +60,10 @@ MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0"
# files are processed.
# * `live` is the option to use when you want to process a live audio stream. It will process the audio in chunks,
# and return the results as soon as they are available. Live option is still a feature in development.
# * `only_transcription` is used to deploy a single transcription server.
# This option is used when you want to deploy each service in a separate server.
# * `only_diarization` is used to deploy a single diarization server.
# This option is used when you want to deploy each service in a separate server.
# Use `live` only if you need live results, otherwise, use `async`.
ASR_TYPE="async"
#
Expand Down Expand Up @@ -106,4 +110,17 @@ SVIX_API_KEY=
# The svix_app_id parameter is used in the cortex implementation to enable webhooks.
SVIX_APP_ID=
#
# -------------------------------------------------- REMOTE SERVERS -------------------------------------------------- #
# The remote servers configuration is used to control the number of servers used to process the requests if you don't
# want to group all the services in one server.
#
# The TRANSCRIBE_SERVER_URLS parameter is used to control the URLs of the servers used to process the requests.
# Each url should be separated by a comma and have this format: "host:port".
# e.g. TRANSCRIBE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
TRANSCRIBE_SERVER_URLS=
# The DIARIZE_SERVER_URLS parameter is used to control the URLs of the servers used to process the requests.
# Each url should be separated by a comma and have this format: "host:port".
# e.g. DIARIZE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
DIARIZE_SERVER_URLS=
#
# -------------------------------------------------------------------------------------------------------------------- #
44 changes: 34 additions & 10 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
"""Configuration module of the Wordcab Transcribe."""

from os import getenv
from typing import Dict, List
from typing import Dict, List, Union

from dotenv import load_dotenv
from loguru import logger
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from wordcab_transcribe import __version__

Expand All @@ -44,14 +45,14 @@ class Settings:
# Whisper
whisper_model: str
compute_type: str
extra_languages: List[str]
extra_languages_model_paths: Dict[str, str]
extra_languages: Union[List[str], None]
extra_languages_model_paths: Union[Dict[str, str], None]
# Diarization
window_lengths: List[float]
shift_lengths: List[float]
multiscale_weights: List[float]
# ASR type configuration
asr_type: str
asr_type: Literal["async", "live", "only_transcription", "only_diarization"]
# Endpoints configuration
audio_file_endpoint: bool
audio_url_endpoint: bool
Expand All @@ -69,6 +70,9 @@ class Settings:
# Svix configuration
svix_api_key: str
svix_app_id: str
# Remote servers configuration
transcribe_server_urls: Union[List[str], None]
diarize_server_urls: Union[List[str], None]

@field_validator("project_name")
def project_name_must_not_be_none(cls, value: str): # noqa: B902, N805
Expand Down Expand Up @@ -213,29 +217,46 @@ def __post_init__(self):
# Extra languages
_extra_languages = getenv("EXTRA_LANGUAGES", None)
if _extra_languages is not None and _extra_languages != "":
extra_languages = _extra_languages.split(",")
extra_languages = [lang.strip() for lang in _extra_languages.split(",")]
else:
extra_languages = []
extra_languages = None

extra_languages_model_paths = (
{lang: "" for lang in extra_languages} if extra_languages is not None else None
)

# Diarization scales
_window_lengths = getenv("WINDOW_LENGTHS", None)
if _window_lengths is not None:
window_lengths = [float(x) for x in _window_lengths.split(",")]
window_lengths = [float(x.strip()) for x in _window_lengths.split(",")]
else:
window_lengths = [1.5, 1.25, 1.0, 0.75, 0.5]

_shift_lengths = getenv("SHIFT_LENGTHS", None)
if _shift_lengths is not None:
shift_lengths = [float(x) for x in _shift_lengths.split(",")]
shift_lengths = [float(x.strip()) for x in _shift_lengths.split(",")]
else:
shift_lengths = [0.75, 0.625, 0.5, 0.375, 0.25]

_multiscale_weights = getenv("MULTISCALE_WEIGHTS", None)
if _multiscale_weights is not None:
multiscale_weights = [float(x) for x in _multiscale_weights.split(",")]
multiscale_weights = [float(x.strip()) for x in _multiscale_weights.split(",")]
else:
multiscale_weights = [1.0, 1.0, 1.0, 1.0, 1.0]

# Multi-servers configuration
_transcribe_server_urls = getenv("TRANSCRIBE_SERVER_URLS", None)
if _transcribe_server_urls is not None and _transcribe_server_urls != "":
transcribe_server_urls = [url.strip() for url in _transcribe_server_urls.split(",")]
else:
transcribe_server_urls = None

_diarize_server_urls = getenv("DIARIZE_SERVER_URLS", None)
if _diarize_server_urls is not None and _diarize_server_urls != "":
diarize_server_urls = [url.strip() for url in _diarize_server_urls.split(",")]
else:
diarize_server_urls = None

settings = Settings(
# General configuration
project_name=getenv("PROJECT_NAME", "Wordcab Transcribe"),
Expand All @@ -252,7 +273,7 @@ def __post_init__(self):
whisper_model=getenv("WHISPER_MODEL", "large-v2"),
compute_type=getenv("COMPUTE_TYPE", "float16"),
extra_languages=extra_languages,
extra_languages_model_paths={lang: "" for lang in extra_languages},
extra_languages_model_paths=extra_languages_model_paths,
# Diarization
window_lengths=window_lengths,
shift_lengths=shift_lengths,
Expand All @@ -276,4 +297,7 @@ def __post_init__(self):
# Svix configuration
svix_api_key=getenv("SVIX_API_KEY", ""),
svix_app_id=getenv("SVIX_APP_ID", ""),
# Remote servers configuration
transcribe_server_urls=transcribe_server_urls,
diarize_server_urls=diarize_server_urls,
)
45 changes: 26 additions & 19 deletions src/wordcab_transcribe/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,14 @@
multiscale_weights=settings.multiscale_weights,
extra_languages=settings.extra_languages,
extra_languages_model_paths=settings.extra_languages_model_paths,
transcribe_server_urls=settings.transcribe_server_urls,
diarize_server_urls=settings.diarize_server_urls,
debug_mode=settings.debug,
)
elif settings.asr_type == "only_transcription":
asr = None
elif settings.asr_type == "only_diarization":
asr = None
else:
raise ValueError(f"Invalid ASR type: {settings.asr_type}")

Expand All @@ -69,28 +75,29 @@ async def lifespan(app: FastAPI) -> None:
" https://github.com/Wordcab/wordcab-transcribe/issues"
)

if check_ffmpeg() is False:
logger.warning(
"FFmpeg is not installed on the host machine.\n"
"Please install it and try again: `sudo apt-get install ffmpeg`"
)
exit(1)
if settings.asr_type == "async" or settings.asr_type == "remote_transcribe":
if check_ffmpeg() is False:
logger.warning(
"FFmpeg is not installed on the host machine.\n"
"Please install it and try again: `sudo apt-get install ffmpeg`"
)
exit(1)

if settings.extra_languages:
logger.info("Downloading models for extra languages...")
for model in settings.extra_languages:
try:
model_path = download_model(
compute_type=settings.compute_type, language=model
)
if settings.extra_languages is not None:
logger.info("Downloading models for extra languages...")
for model in settings.extra_languages:
try:
model_path = download_model(
compute_type=settings.compute_type, language=model
)

if model_path is not None:
settings.extra_languages_model_paths[model] = model_path
else:
raise Exception(f"Couldn't download model for {model}")
if model_path is not None:
settings.extra_languages_model_paths[model] = model_path
else:
raise Exception(f"Couldn't download model for {model}")

except Exception as e:
logger.error(f"Error downloading model for {model}: {e}")
except Exception as e:
logger.error(f"Error downloading model for {model}: {e}")

logger.info("Warmup initialization...")
await asr.inference_warmup()
Expand Down
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
else:
app.include_router(api_router, prefix=settings.api_prefix)

if settings.cortex_endpoint:
if settings.cortex_endpoint and settings.asr_type == "async":
app.include_router(cortex_router, tags=["cortex"])


Expand Down
28 changes: 19 additions & 9 deletions src/wordcab_transcribe/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
class ProcessTimes(BaseModel):
"""The execution times of the different processes."""

total: float
transcription: float
diarization: Union[float, None]
post_processing: float
total: Union[float, None] = None
transcription: Union[float, None] = None
diarization: Union[float, None] = None
post_processing: Union[float, None] = None


class Timestamps(str, Enum):
Expand Down Expand Up @@ -72,7 +72,7 @@ class BaseResponse(BaseModel):
diarization: bool
source_lang: str
timestamps: str
vocab: List[str]
vocab: Union[List[str], None]
word_timestamps: bool
internal_vad: bool
repetition_penalty: float
Expand Down Expand Up @@ -219,7 +219,7 @@ class CortexPayload(BaseModel):
multi_channel: Optional[bool] = False
source_lang: Optional[str] = "en"
timestamps: Optional[Timestamps] = Timestamps.seconds
vocab: Optional[List[str]] = []
vocab: Union[List[str], None] = None
word_timestamps: Optional[bool] = False
internal_vad: Optional[bool] = False
repetition_penalty: Optional[float] = 1.2
Expand Down Expand Up @@ -386,7 +386,7 @@ class BaseRequest(BaseModel):
diarization: bool = False
source_lang: str = "en"
timestamps: Timestamps = Timestamps.seconds
vocab: List[str] = []
vocab: Union[List[str], None] = None
word_timestamps: bool = False
internal_vad: bool = False
repetition_penalty: float = 1.2
Expand All @@ -397,10 +397,12 @@ class BaseRequest(BaseModel):

@field_validator("vocab")
def validate_each_vocab_value(
cls, value: List[str] # noqa: B902, N805
cls, value: Union[List[str], None] # noqa: B902, N805
) -> List[str]:
"""Validate the value of each vocab field."""
if not all(isinstance(v, str) for v in value):
if value == []:
return None
elif value is not None and not all(isinstance(v, str) for v in value):
raise ValueError("`vocab` must be a list of strings.")

return value
Expand Down Expand Up @@ -480,6 +482,14 @@ class Config:
}


class DiarizeResponse(BaseModel):
    """Response model for the diarize endpoint."""
    # NOTE(review): intentionally empty — presumably a placeholder schema for the
    # standalone diarization server added in this commit; confirm before relying
    # on its (absent) fields.


class TranscribeResponse(BaseModel):
    """Response model for the transcribe endpoint."""
    # NOTE(review): intentionally empty — presumably a placeholder schema for the
    # standalone transcription server added in this commit; confirm before relying
    # on its (absent) fields.


class Token(BaseModel):
"""Token model for authentication."""

Expand Down
72 changes: 72 additions & 0 deletions src/wordcab_transcribe/pydantic_annotations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright 2023 The Wordcab Team. All rights reserved.
#
# Licensed under the Wordcab Transcribe License 0.1 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE
#
# Except as expressly provided otherwise herein, and to the fullest
# extent permitted by law, Licensor provides the Software (and each
# Contributor provides its Contributions) AS IS, and Licensor
# disclaims all warranties or guarantees of any kind, express or
# implied, whether arising under any law or from any usage in trade,
# or otherwise including but not limited to the implied warranties
# of merchantability, non-infringement, quiet enjoyment, fitness
# for a particular purpose, or otherwise.
#
# See the License for the specific language governing permissions
# and limitations under the License.
"""Custom Pydantic annotations for the Wordcab Transcribe API."""

from typing import Any

import torch
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


class TorchTensorPydanticAnnotation:
    """Pydantic annotation for torch.Tensor."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        """Custom validation and serialization for torch.Tensor."""

        def _ensure_tensor(candidate) -> torch.Tensor:
            # Pass tensors through untouched; anything else is a validation error.
            if isinstance(candidate, torch.Tensor):
                return candidate
            raise ValueError(f"Expected a torch.Tensor but got {type(candidate)}")

        validator = core_schema.no_info_plain_validator_function(_ensure_tensor)

        return core_schema.chain_schema([validator])

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """Represent a tensor in JSON Schema as an object carrying its metadata."""
        dtype_property = {
            "type": "string",
            "description": "Data type of the tensor",
        }
        shape_property = {
            "type": "array",
            "items": {"type": "integer"},
            "description": "Shape of the tensor",
        }
        return {
            "type": "object",
            "properties": {
                "dtype": dtype_property,
                "shape": shape_property,
            },
            "required": ["dtype", "shape"],
        }
Loading

0 comments on commit 74fd976

Please sign in to comment.