Allow local or remote ASR process #258

Merged: 8 commits, Sep 29, 2023
19 changes: 18 additions & 1 deletion .env
@@ -8,7 +8,7 @@
# The name of the project, used for API documentation.
PROJECT_NAME="Wordcab Transcribe"
# The version of the project, used for API documentation.
VERSION="0.5.0"
VERSION="0.5.1"
# The description of the project, used for API documentation.
DESCRIPTION="💬 ASR FastAPI server using faster-whisper and Auto-Tuning Spectral Clustering for diarization."
# This API prefix is used for all endpoints in the API outside of the status and cortex endpoints.
@@ -60,6 +60,10 @@ MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0"
# files are processed.
# * `live` is the option to use when you want to process a live audio stream. It will process the audio in chunks,
# and return the results as soon as they are available. Live option is still a feature in development.
# * `only_transcription` is used to deploy a single transcription server.
# This option is used when you want to deploy each service in a separate server.
# * `only_diarization` is used to deploy a single diarization server.
# This option is used when you want to deploy each service in a separate server.
# Use `live` only if you need live results, otherwise, use `async`.
ASR_TYPE="async"
#
@@ -106,4 +110,17 @@ SVIX_API_KEY=
# The svix_app_id parameter is used in the cortex implementation to enable webhooks.
SVIX_APP_ID=
#
# -------------------------------------------------- REMOTE SERVERS -------------------------------------------------- #
# The remote servers configuration is used to control the number of servers used to process the requests if you don't
# want to group all the services in one server.
#
# The TRANSCRIBE_SERVER_URLS parameter is used to control the URLs of the servers used to process the requests.
# Each URL should be separated by a comma and use this format: "http://host:port".
# e.g. TRANSCRIBE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
TRANSCRIBE_SERVER_URLS=
# The DIARIZE_SERVER_URLS parameter is used to control the URLs of the servers used to process the requests.
# Each URL should be separated by a comma and use this format: "http://host:port".
# e.g. DIARIZE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000"
DIARIZE_SERVER_URLS=
#
# -------------------------------------------------------------------------------------------------------------------- #
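
For reference, a minimal sketch of how these comma-separated variables are parsed into URL lists, mirroring the `__post_init__` logic added to `config.py` below; the standalone `parse_urls` helper is illustrative and not part of this PR:

```python
from os import getenv
from typing import List, Union


def parse_urls(env_var: str) -> Union[List[str], None]:
    """Return a list of trimmed URLs, or None when the variable is unset or empty."""
    raw = getenv(env_var, None)
    if raw is None or raw == "":
        return None
    return [url.strip() for url in raw.split(",")]


# e.g. TRANSCRIBE_SERVER_URLS="http://1.2.3.4:8000,http://4.3.2.1:8000" in .env
transcribe_server_urls = parse_urls("TRANSCRIBE_SERVER_URLS")
diarize_server_urls = parse_urls("DIARIZE_SERVER_URLS")  # None when left empty
```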
44 changes: 34 additions & 10 deletions src/wordcab_transcribe/config.py
@@ -20,12 +20,13 @@
"""Configuration module of the Wordcab Transcribe."""

from os import getenv
from typing import Dict, List
from typing import Dict, List, Union

from dotenv import load_dotenv
from loguru import logger
from pydantic import field_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from wordcab_transcribe import __version__

@@ -44,14 +45,14 @@ class Settings:
# Whisper
whisper_model: str
compute_type: str
extra_languages: List[str]
extra_languages_model_paths: Dict[str, str]
extra_languages: Union[List[str], None]
extra_languages_model_paths: Union[Dict[str, str], None]
# Diarization
window_lengths: List[float]
shift_lengths: List[float]
multiscale_weights: List[float]
# ASR type configuration
asr_type: str
asr_type: Literal["async", "live", "only_transcription", "only_diarization"]
# Endpoints configuration
audio_file_endpoint: bool
audio_url_endpoint: bool
@@ -69,6 +70,9 @@ class Settings:
# Svix configuration
svix_api_key: str
svix_app_id: str
# Remote servers configuration
transcribe_server_urls: Union[List[str], None]
diarize_server_urls: Union[List[str], None]

@field_validator("project_name")
def project_name_must_not_be_none(cls, value: str): # noqa: B902, N805
@@ -213,29 +217,46 @@ def __post_init__(self):
# Extra languages
_extra_languages = getenv("EXTRA_LANGUAGES", None)
if _extra_languages is not None and _extra_languages != "":
extra_languages = _extra_languages.split(",")
extra_languages = [lang.strip() for lang in _extra_languages.split(",")]
else:
extra_languages = []
extra_languages = None

extra_languages_model_paths = (
{lang: "" for lang in extra_languages} if extra_languages is not None else None
)

# Diarization scales
_window_lengths = getenv("WINDOW_LENGTHS", None)
if _window_lengths is not None:
window_lengths = [float(x) for x in _window_lengths.split(",")]
window_lengths = [float(x.strip()) for x in _window_lengths.split(",")]
else:
window_lengths = [1.5, 1.25, 1.0, 0.75, 0.5]

_shift_lengths = getenv("SHIFT_LENGTHS", None)
if _shift_lengths is not None:
shift_lengths = [float(x) for x in _shift_lengths.split(",")]
shift_lengths = [float(x.strip()) for x in _shift_lengths.split(",")]
else:
shift_lengths = [0.75, 0.625, 0.5, 0.375, 0.25]

_multiscale_weights = getenv("MULTISCALE_WEIGHTS", None)
if _multiscale_weights is not None:
multiscale_weights = [float(x) for x in _multiscale_weights.split(",")]
multiscale_weights = [float(x.strip()) for x in _multiscale_weights.split(",")]
else:
multiscale_weights = [1.0, 1.0, 1.0, 1.0, 1.0]

# Multi-servers configuration
_transcribe_server_urls = getenv("TRANSCRIBE_SERVER_URLS", None)
if _transcribe_server_urls is not None and _transcribe_server_urls != "":
transcribe_server_urls = [url.strip() for url in _transcribe_server_urls.split(",")]
else:
transcribe_server_urls = None

_diarize_server_urls = getenv("DIARIZE_SERVER_URLS", None)
if _diarize_server_urls is not None and _diarize_server_urls != "":
diarize_server_urls = [url.strip() for url in _diarize_server_urls.split(",")]
else:
diarize_server_urls = None

settings = Settings(
# General configuration
project_name=getenv("PROJECT_NAME", "Wordcab Transcribe"),
@@ -252,7 +273,7 @@ def __post_init__(self):
whisper_model=getenv("WHISPER_MODEL", "large-v2"),
compute_type=getenv("COMPUTE_TYPE", "float16"),
extra_languages=extra_languages,
extra_languages_model_paths={lang: "" for lang in extra_languages},
extra_languages_model_paths=extra_languages_model_paths,
# Diarization
window_lengths=window_lengths,
shift_lengths=shift_lengths,
@@ -276,4 +297,7 @@ def __post_init__(self):
# Svix configuration
svix_api_key=getenv("SVIX_API_KEY", ""),
svix_app_id=getenv("SVIX_APP_ID", ""),
# Remote servers configuration
transcribe_server_urls=transcribe_server_urls,
diarize_server_urls=diarize_server_urls,
)
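
The switch from a plain `str` to a `Literal` means an invalid `ASR_TYPE` is rejected as soon as `Settings` is constructed. A minimal standalone sketch of that behaviour (the `ServerSettings` dataclass here is illustrative, far smaller than the real `Settings` class):

```python
from pydantic import ValidationError
from pydantic.dataclasses import dataclass
from typing_extensions import Literal


@dataclass
class ServerSettings:
    asr_type: Literal["async", "live", "only_transcription", "only_diarization"]


ServerSettings(asr_type="only_transcription")  # accepted

try:
    ServerSettings(asr_type="remote")  # rejected at construction time
except ValidationError as err:
    print(err)
```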
45 changes: 26 additions & 19 deletions src/wordcab_transcribe/dependencies.py
@@ -52,8 +52,14 @@
multiscale_weights=settings.multiscale_weights,
extra_languages=settings.extra_languages,
extra_languages_model_paths=settings.extra_languages_model_paths,
transcribe_server_urls=settings.transcribe_server_urls,
diarize_server_urls=settings.diarize_server_urls,
debug_mode=settings.debug,
)
elif settings.asr_type == "only_transcription":
asr = None
elif settings.asr_type == "only_diarization":
asr = None
else:
raise ValueError(f"Invalid ASR type: {settings.asr_type}")

@@ -69,28 +75,29 @@ async def lifespan(app: FastAPI) -> None:
" https://github.com/Wordcab/wordcab-transcribe/issues"
)

if check_ffmpeg() is False:
logger.warning(
"FFmpeg is not installed on the host machine.\n"
"Please install it and try again: `sudo apt-get install ffmpeg`"
)
exit(1)
if settings.asr_type == "async" or settings.asr_type == "remote_transcribe":
if check_ffmpeg() is False:
logger.warning(
"FFmpeg is not installed on the host machine.\n"
"Please install it and try again: `sudo apt-get install ffmpeg`"
)
exit(1)

if settings.extra_languages:
logger.info("Downloading models for extra languages...")
for model in settings.extra_languages:
try:
model_path = download_model(
compute_type=settings.compute_type, language=model
)
if settings.extra_languages is not None:
logger.info("Downloading models for extra languages...")
for model in settings.extra_languages:
try:
model_path = download_model(
compute_type=settings.compute_type, language=model
)

if model_path is not None:
settings.extra_languages_model_paths[model] = model_path
else:
raise Exception(f"Coudn't download model for {model}")
if model_path is not None:
settings.extra_languages_model_paths[model] = model_path
else:
raise Exception(f"Coudn't download model for {model}")

except Exception as e:
logger.error(f"Error downloading model for {model}: {e}")
except Exception as e:
logger.error(f"Error downloading model for {model}: {e}")

logger.info("Warmup initialization...")
await asr.inference_warmup()
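
A condensed, self-contained sketch of the dispatch this diff introduces: a full service object is only built for the `async` and `live` types, while single-purpose servers start with `asr = None` (the `build_asr` helper and its stub returns are illustrative only):

```python
from typing import Optional


def build_asr(asr_type: str) -> Optional[object]:
    if asr_type == "async":
        return object()  # stands in for ASRAsyncService(...)
    elif asr_type == "live":
        return object()  # stands in for ASRLiveService(...)
    elif asr_type in ("only_transcription", "only_diarization"):
        return None  # dedicated transcription/diarization services are wired up elsewhere
    raise ValueError(f"Invalid ASR type: {asr_type}")


assert build_asr("only_diarization") is None
```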
2 changes: 1 addition & 1 deletion src/wordcab_transcribe/main.py
@@ -55,7 +55,7 @@
else:
app.include_router(api_router, prefix=settings.api_prefix)

if settings.cortex_endpoint:
if settings.cortex_endpoint and settings.asr_type == "async":
app.include_router(cortex_router, tags=["cortex"])


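With this change, cortex endpoints are only mounted on full `async` deployments. A small illustrative sketch of that guard (the `build_app` helper is not part of the PR):

```python
from fastapi import APIRouter, FastAPI

cortex_router = APIRouter()


def build_app(cortex_endpoint: bool, asr_type: str) -> FastAPI:
    app = FastAPI()
    if cortex_endpoint and asr_type == "async":
        app.include_router(cortex_router, tags=["cortex"])
    return app


# A transcription-only server exposes no cortex routes even when the endpoint is enabled.
app = build_app(cortex_endpoint=True, asr_type="only_transcription")
```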
28 changes: 19 additions & 9 deletions src/wordcab_transcribe/models.py
@@ -28,10 +28,10 @@
class ProcessTimes(BaseModel):
"""The execution times of the different processes."""

total: float
transcription: float
diarization: Union[float, None]
post_processing: float
total: Union[float, None] = None
transcription: Union[float, None] = None
diarization: Union[float, None] = None
post_processing: Union[float, None] = None


class Timestamps(str, Enum):
@@ -72,7 +72,7 @@ class BaseResponse(BaseModel):
diarization: bool
source_lang: str
timestamps: str
vocab: List[str]
vocab: Union[List[str], None]
word_timestamps: bool
internal_vad: bool
repetition_penalty: float
@@ -219,7 +219,7 @@ class CortexPayload(BaseModel):
multi_channel: Optional[bool] = False
source_lang: Optional[str] = "en"
timestamps: Optional[Timestamps] = Timestamps.seconds
vocab: Optional[List[str]] = []
vocab: Union[List[str], None] = None
word_timestamps: Optional[bool] = False
internal_vad: Optional[bool] = False
repetition_penalty: Optional[float] = 1.2
@@ -386,7 +386,7 @@ class BaseRequest(BaseModel):
diarization: bool = False
source_lang: str = "en"
timestamps: Timestamps = Timestamps.seconds
vocab: List[str] = []
vocab: Union[List[str], None] = None
word_timestamps: bool = False
internal_vad: bool = False
repetition_penalty: float = 1.2
@@ -397,10 +397,12 @@ class BaseRequest(BaseModel):

@field_validator("vocab")
def validate_each_vocab_value(
cls, value: List[str] # noqa: B902, N805
cls, value: Union[List[str], None] # noqa: B902, N805
) -> Union[List[str], None]:
"""Validate the value of each vocab field."""
if not all(isinstance(v, str) for v in value):
if value == []:
return None
elif value is not None and not all(isinstance(v, str) for v in value):
raise ValueError("`vocab` must be a list of strings.")

return value
@@ -480,6 +482,14 @@ class Config:
}


class DiarizeResponse(BaseModel):
"""Response model for the diarize endpoint."""


class TranscribeResponse(BaseModel):
"""Response model for the transcribe endpoint."""


class Token(BaseModel):
"""Token model for authentication."""

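A standalone sketch of the vocab normalisation added above: an empty list is coerced to `None` and non-string entries are rejected (the `MiniRequest` model is illustrative only):

```python
from typing import List, Union

from pydantic import BaseModel, field_validator


class MiniRequest(BaseModel):
    vocab: Union[List[str], None] = None

    @field_validator("vocab")
    def validate_each_vocab_value(cls, value):  # noqa: N805
        """Coerce an empty list to None and require string entries otherwise."""
        if value == []:
            return None
        elif value is not None and not all(isinstance(v, str) for v in value):
            raise ValueError("`vocab` must be a list of strings.")
        return value


assert MiniRequest(vocab=[]).vocab is None
assert MiniRequest(vocab=["Wordcab", "Transcribe"]).vocab == ["Wordcab", "Transcribe"]
```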
72 changes: 72 additions & 0 deletions src/wordcab_transcribe/pydantic_annotations.py
@@ -0,0 +1,72 @@
# Copyright 2023 The Wordcab Team. All rights reserved.
#
# Licensed under the Wordcab Transcribe License 0.1 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE
#
# Except as expressly provided otherwise herein, and to the fullest
# extent permitted by law, Licensor provides the Software (and each
# Contributor provides its Contributions) AS IS, and Licensor
# disclaims all warranties or guarantees of any kind, express or
# implied, whether arising under any law or from any usage in trade,
# or otherwise including but not limited to the implied warranties
# of merchantability, non-infringement, quiet enjoyment, fitness
# for a particular purpose, or otherwise.
#
# See the License for the specific language governing permissions
# and limitations under the License.
"""Custom Pydantic annotations for the Wordcab Transcribe API."""

from typing import Any

import torch
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema


class TorchTensorPydanticAnnotation:
"""Pydantic annotation for torch.Tensor."""

@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
) -> core_schema.CoreSchema:
"""Custom validation and serialization for torch.Tensor."""

def validate_tensor(value) -> torch.Tensor:
if not isinstance(value, torch.Tensor):
raise ValueError(f"Expected a torch.Tensor but got {type(value)}")
return value

return core_schema.chain_schema(
[
core_schema.no_info_plain_validator_function(validate_tensor),
]
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
# This is a custom representation for the tensor in JSON Schema.
# Here, it represents a tensor as an object with metadata.
return {
"type": "object",
"properties": {
"dtype": {
"type": "string",
"description": "Data type of the tensor",
},
"shape": {
"type": "array",
"items": {"type": "integer"},
"description": "Shape of the tensor",
},
},
"required": ["dtype", "shape"],
}
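
A hypothetical usage sketch, not part of this diff: the annotation is attached to a model field via `typing.Annotated`, so `torch.Tensor` values pass validation (the `AudioPayload` model and the `waveform` field name are assumptions):

```python
from typing import Annotated

import torch
from pydantic import BaseModel

from wordcab_transcribe.pydantic_annotations import TorchTensorPydanticAnnotation


class AudioPayload(BaseModel):
    """Illustrative model; not part of the Wordcab Transcribe API."""

    waveform: Annotated[torch.Tensor, TorchTensorPydanticAnnotation]


payload = AudioPayload(waveform=torch.zeros(16000))  # a non-tensor value raises a ValidationError
```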