Skip to content

Commit

Permalink
Implement only_transcription and only_diarization (#261)
Browse files Browse the repository at this point in the history
* simplify to cortex only config endpoints

* add a test script for audio-url

* add the transcription only endpoint

* move tensorshare to inference

* fix deps + tests

* rename transcribe endpoint

* delete the audio custom type

* add remote capability for transcription

* update remote models and process

* Fix post-processing and naming

* add diarization only + update post-processing

* update process

* fix logging and Word schema

* remove auto creation of all the services

* fix tests
  • Loading branch information
Thomas Chaigneau authored Oct 4, 2023
1 parent 74fd976 commit f8d6510
Show file tree
Hide file tree
Showing 19 changed files with 849 additions and 463 deletions.
8 changes: 0 additions & 8 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,9 @@ ASR_TYPE="async"
#
# --------------------------------------------- ENDPOINTS CONFIGURATION ---------------------------------------------- #
#
# Include the `audio` endpoint in the API. This endpoint is used to process uploaded local audio files.
AUDIO_FILE_ENDPOINT=True
# Include the `audio-url` endpoint in the API. This endpoint is used to process audio files from a URL.
AUDIO_URL_ENDPOINT=True
# Include the cortex endpoint in the API. This endpoint is used to process audio files from the Cortex API.
# Use this only if you deploy the API using Cortex and Kubernetes.
CORTEX_ENDPOINT=True
# Include the `youtube` endpoint in the API. This endpoint is used to process audio files from YouTube URLs.
YOUTUBE_ENDPOINT=True
# Include the `live` endpoint in the API. This endpoint is used to process live audio streams.
LIVE_ENDPOINT=False
#
# ---------------------------------------- API AUTHENTICATION CONFIGURATION ------------------------------------------ #
# The API authentication is used to control the access to the API endpoints.
Expand Down
30 changes: 30 additions & 0 deletions notebooks/audio_url_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json
import requests

# Example client for the `audio-url` endpoint: asks a locally running API to
# transcribe a remote audio file and stores the JSON answer on disk.
headers = {"accept": "application/json", "Content-Type": "application/json"}
params = {"url": "https://github.com/Wordcab/wordcab-python/raw/main/tests/sample_1.mp3"}


# JSON body sent along with the request (the audio URL itself goes in `params`).
data = {
    "offset_start": None,
    "offset_end": None,
    "num_speakers": -1,  # Leave at -1 to guess the number of speakers
    "diarization": True,  # Longer processing time but speaker segment attribution
    "source_lang": "en",  # optional, default is "en"
    "timestamps": "s",  # optional, default is "s". Can be "s", "ms" or "hms".
    "internal_vad": False,  # optional, default is False
    "vocab": ["Martha's Flowers", "Thomas", "Randal"],  # optional, default is None
    "word_timestamps": False,  # optional, default is False
}

response = requests.post(
    "http://localhost:5001/api/v1/audio-url",
    headers=headers,
    params=params,
    data=json.dumps(data),
)
# Fail loudly on 4xx/5xx answers: without this check an error payload would be
# written to the output file as if it were a transcription result.
response.raise_for_status()

r_json = response.json()

with open("data/audio_url_output.json", "w", encoding="utf-8") as f:
    json.dump(r_json, f, indent=4, ensure_ascii=False)
124 changes: 124 additions & 0 deletions notebooks/transcribe_endpoint_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import asyncio
import io
import json
from typing import List, Tuple, Union

import aiohttp
import soundfile as sf
import torch
import torchaudio
from pydantic import BaseModel
from tensorshare import Backend, TensorShare


def read_audio(
    audio: Union[str, bytes],
    offset_start: Union[float, None] = None,
    offset_end: Union[float, None] = None,
    sample_rate: int = 16000,
) -> Tuple[torch.Tensor, float]:
    """
    Read an audio file and return the audio tensor.

    A path is loaded with `torchaudio.load`; raw bytes are decoded as mono
    16-bit PCM sampled at 16 kHz. Multi-channel audio is averaged down to a
    single channel and the signal is resampled to `sample_rate` when needed.

    Args:
        audio (Union[str, bytes]):
            Path to the audio file or the audio bytes.
        offset_start (Union[float, None], optional):
            When to start reading the audio file, in seconds. Defaults to None.
        offset_end (Union[float, None], optional):
            When to stop reading the audio file, in seconds. Defaults to None.
        sample_rate (int):
            The sample rate of the audio file. Defaults to 16000.

    Returns:
        Tuple[torch.Tensor, float]: The (trimmed) 1-D audio tensor and the
        duration in seconds of the full, untrimmed audio.

    Raises:
        ValueError: If `audio` is neither `str` nor `bytes`, if an offset is
            negative, or if `offset_end` is smaller than `offset_start`.
    """
    if not isinstance(audio, (str, bytes)):
        raise ValueError(
            f"Invalid audio type. Must be either str or bytes, got: {type(audio)}."
        )

    # Validate the offsets before touching the audio: a negative value would be
    # turned into a negative slice index below and silently select samples
    # counted from the END of the signal instead of failing.
    if offset_start is not None and offset_start < 0:
        raise ValueError(f"offset_start must be >= 0, got: {offset_start}.")
    if offset_end is not None and offset_end < 0:
        raise ValueError(f"offset_end must be >= 0, got: {offset_end}.")
    if (
        offset_start is not None
        and offset_end is not None
        and offset_end < offset_start
    ):
        raise ValueError(
            "offset_end must be >= offset_start, got:"
            f" offset_start={offset_start}, offset_end={offset_end}."
        )

    if isinstance(audio, str):
        wav, sr = torchaudio.load(audio)
    else:
        with io.BytesIO(audio) as buffer:
            wav, sr = sf.read(
                buffer, format="RAW", channels=1, samplerate=16000, subtype="PCM_16"
            )
        # soundfile returns a 1-D numpy array here; add the channel dimension
        # so both branches produce a (channels, samples) tensor.
        wav = torch.from_numpy(wav).unsqueeze(0)

    # Down-mix multi-channel audio to mono.
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)

    if sr != sample_rate:
        transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        wav = transform(wav)
        sr = sample_rate

    # Duration is measured on the whole signal, before any trimming.
    audio_duration = float(wav.shape[1]) / sample_rate

    # Convert start and end times to sample indices based on the new sample rate
    start_sample = 0 if offset_start is None else int(offset_start * sr)
    end_sample = wav.shape[1] if offset_end is None else int(offset_end * sr)

    # Trim the audio based on the new start and end samples
    wav = wav[:, start_sample:end_sample]

    return wav.squeeze(0), audio_duration


class TranscribeRequest(BaseModel):
    """Request model for the transcribe endpoint."""

    # Every field is required: none of them declares a default value.
    # NOTE(review): presumably mirrors the server-side schema of the
    # `/api/v1/transcribe` endpoint this script posts to — confirm.

    # Audio payload: a single TensorShare object or a list of them.
    audio: Union[TensorShare, List[TensorShare]]
    # NOTE(review): the options below look like decoding parameters forwarded
    # to the transcription model — confirm their exact semantics server-side.
    compression_ratio_threshold: float
    condition_on_previous_text: bool
    internal_vad: bool
    log_prob_threshold: float
    no_speech_threshold: float
    repetition_penalty: float
    # Language code of the audio; this script sends "en".
    source_lang: str
    # Optional list of vocabulary strings; None when no vocabulary is given.
    vocab: Union[List[str], None]


async def main():
    """
    Send a local audio file to the remote transcription-only endpoint.

    Loads `data/HL_Podcast_1.mp3`, wraps the audio tensor in a TensorShare
    object, posts it as JSON to the `/api/v1/transcribe` endpoint and writes
    the JSON answer to `remote_test.json`.

    Raises:
        RuntimeError: If the endpoint answers with a status other than 200.
    """
    audio, _ = read_audio("data/HL_Podcast_1.mp3")
    ts = TensorShare.from_dict({"audio": audio}, backend=Backend.TORCH)

    data = TranscribeRequest(
        audio=ts,
        source_lang="en",
        compression_ratio_threshold=2.4,
        condition_on_previous_text=True,
        internal_vad=False,
        log_prob_threshold=-1.0,
        no_speech_threshold=0.6,
        repetition_penalty=1.0,
        vocab=None,
    )

    async with aiohttp.ClientSession() as session:
        async with session.post(
            # `0.0.0.0` is a bind address, not a portable destination (it is
            # rejected on some platforms); target the loopback interface.
            url="http://127.0.0.1:5002/api/v1/transcribe",
            data=data.model_dump_json(),
            headers={"Content-Type": "application/json"},
        ) as response:
            if response.status != 200:
                raise RuntimeError(
                    f"Remote transcription failed with status {response.status}."
                )
            r = await response.json()

    # Same output conventions as the audio-url example script: UTF-8 file,
    # non-ASCII characters kept readable.
    with open("remote_test.json", "w", encoding="utf-8") as f:
        json.dump(r, f, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    # Only contact the remote endpoint when executed as a script, so that
    # importing this module (e.g. to reuse `read_audio`) has no side effects.
    asyncio.run(main())
23 changes: 10 additions & 13 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,19 @@ classifiers = [
dependencies = [
"aiohttp>=3.8.4",
"aiofiles>=23.1.0",
"ctranslate2>=3.18.0",
"faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master",
"ffmpeg-python>=0.2.0",
"librosa>=0.9.0",
"loguru>=0.6.0",
"numpy==1.23.1",
"onnxruntime>=1.15.0",
"pydantic>=1.10.9",
"python-dotenv>=1.0.0",
"tensorshare>=0.1.1",
"torch>=2.0.0",
"torchaudio>=2.0.1",
"wget>=3.2.0",
"yt-dlp>=2023.3.4",
]

Expand All @@ -48,16 +57,6 @@ path = "src/wordcab_transcribe/__init__.py"
allow-direct-references = true

[project.optional-dependencies]
inference = [
"ctranslate2>=3.18.0",
"faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master",
"librosa>=0.9.0",
"numpy==1.23.1",
"onnxruntime>=1.15.0",
"torch>=2.0.0",
"torchaudio>=2.0.1",
"wget>=3.2.0",
]
runtime = [
"argon2-cffi>=21.3.0",
"fastapi>=0.96.0",
Expand All @@ -67,7 +66,6 @@ runtime = [
"svix>=0.85.1",
"uvicorn>=0.21.1",
"websockets>=11.0.3",
"wordcab_transcribe[inference]",
]
quality = [
"black>=22.10.0",
Expand All @@ -79,9 +77,8 @@ tests = [
"pytest>=7.4",
"pytest-asyncio>=0.21.1",
"pytest-cov>=4.1",
"wordcab_transcribe[inference]",
]
dev = ["wordcab_transcribe[quality,inference,runtime,tests]"]
dev = ["wordcab_transcribe[quality,runtime,tests]"]

[tool.hatch.envs.runtime]
features = [
Expand Down
30 changes: 1 addition & 29 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,8 @@ class Settings:
multiscale_weights: List[float]
# ASR type configuration
asr_type: Literal["async", "live", "only_transcription", "only_diarization"]
# Endpoints configuration
audio_file_endpoint: bool
audio_url_endpoint: bool
# Endpoint configuration
cortex_endpoint: bool
youtube_endpoint: bool
live_endpoint: bool
# API authentication configuration
username: str
password: str
Expand Down Expand Up @@ -134,16 +130,6 @@ def compute_type_must_be_valid(cls, value: str): # noqa: B902, N805

return value

@field_validator("asr_type")
def asr_type_must_be_valid(cls, value: str): # noqa: B902, N805
"""Check that the ASR type is valid."""
if value not in {"async", "live"}:
raise ValueError(
f"{value} is not a valid ASR type. Choose between `async` or `live`."
)

return value

@field_validator("openssl_algorithm")
def openssl_algorithm_must_be_valid(cls, value: str): # noqa: B902, N805
"""Check that the OpenSSL algorithm is valid."""
Expand All @@ -168,16 +154,6 @@ def access_token_expire_minutes_must_be_valid(cls, value: int): # noqa: B902, N

def __post_init__(self):
"""Post initialization checks."""
endpoints = [
self.audio_file_endpoint,
self.audio_url_endpoint,
self.cortex_endpoint,
self.youtube_endpoint,
self.live_endpoint,
]
if not any(endpoints):
raise ValueError("At least one endpoint configuration must be set to True.")

if self.debug is False:
if self.username == "admin" or self.username is None: # noqa: S105
logger.warning(
Expand Down Expand Up @@ -281,11 +257,7 @@ def __post_init__(self):
# ASR type
asr_type=getenv("ASR_TYPE", "async"),
# Endpoints configuration
audio_file_endpoint=getenv("AUDIO_FILE_ENDPOINT", True),
audio_url_endpoint=getenv("AUDIO_URL_ENDPOINT", True),
cortex_endpoint=getenv("CORTEX_ENDPOINT", True),
youtube_endpoint=getenv("YOUTUBE_ENDPOINT", True),
live_endpoint=getenv("LIVE_ENDPOINT", False),
# API authentication configuration
username=getenv("USERNAME", "admin"),
password=getenv("PASSWORD", "admin"),
Expand Down
24 changes: 20 additions & 4 deletions src/wordcab_transcribe/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
from loguru import logger

from wordcab_transcribe.config import settings
from wordcab_transcribe.services.asr_service import ASRAsyncService, ASRLiveService
from wordcab_transcribe.services.asr_service import (
ASRAsyncService,
ASRDiarizationOnly,
ASRLiveService,
ASRTranscriptionOnly,
)
from wordcab_transcribe.utils import (
check_ffmpeg,
download_model,
Expand Down Expand Up @@ -57,9 +62,20 @@
debug_mode=settings.debug,
)
elif settings.asr_type == "only_transcription":
asr = None
asr = ASRTranscriptionOnly(
whisper_model=settings.whisper_model,
compute_type=settings.compute_type,
extra_languages=settings.extra_languages,
extra_languages_model_paths=settings.extra_languages_model_paths,
debug_mode=settings.debug,
)
elif settings.asr_type == "only_diarization":
asr = None
asr = ASRDiarizationOnly(
window_lengths=settings.window_lengths,
shift_lengths=settings.shift_lengths,
multiscale_weights=settings.multiscale_weights,
debug_mode=settings.debug,
)
else:
raise ValueError(f"Invalid ASR type: {settings.asr_type}")

Expand All @@ -75,7 +91,7 @@ async def lifespan(app: FastAPI) -> None:
" https://github.com/Wordcab/wordcab-transcribe/issues"
)

if settings.asr_type == "async" or settings.asr_type == "remote_transcribe":
if settings.asr_type == "async" or settings.asr_type == "only_transcription":
if check_ffmpeg() is False:
logger.warning(
"FFmpeg is not installed on the host machine.\n"
Expand Down
39 changes: 38 additions & 1 deletion src/wordcab_transcribe/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@

"""Logging module to add a logging middleware to the Wordcab Transcribe API."""

import asyncio
import sys
import time
import uuid
from functools import partial
from typing import Any, Awaitable, Callable, Tuple

from loguru import logger
Expand Down Expand Up @@ -93,7 +95,42 @@ def time_and_tell(
The appropriate wrapper for the function.
"""
start_time = time.time()
result = func()
result = func
process_time = time.time() - start_time

if debug_mode:
logger.debug(f"{func_name} executed in {process_time:.4f} secs")

return result, process_time


async def time_and_tell_async(
func: Callable, func_name: str, debug_mode: bool
) -> Tuple[Any, float]:
"""
This decorator logs the execution time of an async function only if the debug setting is True.
Args:
func: The function to call in the wrapper.
func_name: The name of the function for logging purposes.
debug_mode: The debug setting for logging purposes.
Returns:
The appropriate wrapper for the function.
"""
start_time = time.time()

if asyncio.iscoroutinefunction(func) or asyncio.iscoroutine(func):
result = await func
else:
loop = asyncio.get_event_loop()
if isinstance(func, partial):
result = await loop.run_in_executor(
None, func.func, *func.args, **func.keywords
)
else:
result = await loop.run_in_executor(None, func)

process_time = time.time() - start_time

if debug_mode:
Expand Down
Loading

0 comments on commit f8d6510

Please sign in to comment.