diff --git a/.env b/.env
index e069d39..6a7c1ce 100644
--- a/.env
+++ b/.env
@@ -102,6 +102,13 @@ SVIX_API_KEY=
 # The svix_app_id parameter is used in the cortex implementation to enable webhooks.
 SVIX_APP_ID=
 #
+# ----------------------------------------------- AWS CONFIGURATION ------------------------------------------------- #
+#
+AWS_ACCESS_KEY_ID=
+AWS_SECRET_ACCESS_KEY=
+AWS_STORAGE_BUCKET_NAME=
+AWS_S3_REGION_NAME=
+#
 # -------------------------------------------------- REMOTE SERVERS -------------------------------------------------- #
 # The remote servers configuration is used to control the number of servers used to process the requests if you don't
 # want to group all the services in one server.
diff --git a/pyproject.toml b/pyproject.toml
index ba0adfc..e968d3c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,13 +29,17 @@ classifiers = [
 dependencies = [
     "aiohttp>=3.8.4",
     "aiofiles>=23.1.0",
+    "boto3",
     "ctranslate2>=3.18.0",
     "faster-whisper @ git+https://github.com/Wordcab/faster-whisper@master",
     "ffmpeg-python>=0.2.0",
+    "transformers@git+https://github.com/huggingface/transformers.git@assistant_decoding_batch",
     "librosa>=0.9.0",
     "loguru>=0.6.0",
+    "nltk>=3.8.1",
     "numpy==1.23.1",
     "onnxruntime>=1.15.0",
+    "pandas>=2.1.2",
     "pydantic>=1.10.9",
     "python-dotenv>=1.0.0",
     "tensorshare>=0.1.1",
diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py
index 5d992ae..4a370fb 100644
--- a/src/wordcab_transcribe/config.py
+++ b/src/wordcab_transcribe/config.py
@@ -63,6 +63,11 @@ class Settings:
     access_token_expire_minutes: int
     # Cortex configuration
     cortex_api_key: str
+    # AWS configuration
+    aws_access_key_id: str
+    aws_secret_access_key: str
+    aws_storage_bucket_name: str
+    aws_region_name: str
     # Svix configuration
     svix_api_key: str
     svix_app_id: str
@@ -266,6 +271,11 @@ def __post_init__(self):
     access_token_expire_minutes=getenv("ACCESS_TOKEN_EXPIRE_MINUTES", 30),
     # Cortex configuration
     cortex_api_key=getenv("WORDCAB_TRANSCRIBE_API_KEY", ""),
+    # AWS configuration
+    aws_access_key_id=getenv("AWS_ACCESS_KEY_ID", ""),
+    aws_secret_access_key=getenv("AWS_SECRET_ACCESS_KEY", ""),
+    aws_storage_bucket_name=getenv("AWS_STORAGE_BUCKET_NAME", ""),
+    aws_region_name=getenv("AWS_S3_REGION_NAME", ""),
     # Svix configuration
     svix_api_key=getenv("SVIX_API_KEY", ""),
     svix_app_id=getenv("SVIX_APP_ID", ""),
diff --git a/src/wordcab_transcribe/main.py b/src/wordcab_transcribe/main.py
index 4acdbb5..2ae41f9 100644
--- a/src/wordcab_transcribe/main.py
+++ b/src/wordcab_transcribe/main.py
@@ -46,7 +46,7 @@
 # Add logging middleware
 app.add_middleware(LoggingMiddleware, debug_mode=settings.debug)
 
-# Include the appropiate routers based on the settings
+# Include the appropriate routers based on the settings
 if settings.debug is False:
     app.include_router(auth_router, tags=["authentication"])
     app.include_router(
diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py
index 8203a04..a774278 100644
--- a/src/wordcab_transcribe/models.py
+++ b/src/wordcab_transcribe/models.py
@@ -93,6 +93,8 @@ class BaseResponse(BaseModel):
     no_speech_threshold: float
     condition_on_previous_text: bool
     process_times: ProcessTimes
+    job_name: Optional[str] = None
+    task_token: Optional[str] = None
 
 
 class AudioResponse(BaseResponse):
@@ -240,6 +242,7 @@ class CortexPayload(BaseModel):
     no_speech_threshold: Optional[float] = 0.6
     condition_on_previous_text: Optional[bool] = True
     job_name: Optional[str] = None
+    task_token: Optional[str] = None
     ping: Optional[bool] = False
 
     class Config:
@@
-406,6 +409,8 @@ class BaseRequest(BaseModel): log_prob_threshold: float = -1.0 no_speech_threshold: float = 0.6 condition_on_previous_text: bool = True + job_name: Optional[str] = None + task_token: Optional[str] = None @field_validator("vocab") def validate_each_vocab_value( @@ -518,7 +523,8 @@ class DiarizationOutput(BaseModel): class DiarizationRequest(BaseModel): """Request model for the diarize endpoint.""" - audio: TensorShare + audio: Union[TensorShare, str] + audio_type: Optional[str] duration: float num_speakers: int diff --git a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py index ecc350e..ef6f0aa 100644 --- a/src/wordcab_transcribe/router/v1/audio_url_endpoint.py +++ b/src/wordcab_transcribe/router/v1/audio_url_endpoint.py @@ -20,16 +20,23 @@ """Audio url endpoint for the Wordcab Transcribe API.""" import asyncio +import json +from datetime import datetime from typing import List, Optional, Union +import boto3 import shortuuid from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import status as http_status from loguru import logger +from svix.api import MessageIn, SvixAsync +from wordcab_transcribe.config import settings from wordcab_transcribe.dependencies import asr, download_limit -from wordcab_transcribe.models import AudioRequest, AudioResponse -from wordcab_transcribe.services.asr_service import ProcessException +from wordcab_transcribe.models import ( + AudioRequest, + AudioResponse, +) from wordcab_transcribe.utils import ( check_num_channels, delete_file, @@ -40,86 +47,181 @@ router = APIRouter() -@router.post("", response_model=AudioResponse, status_code=http_status.HTTP_200_OK) +def retrieve_service(service, aws_creds): + return boto3.client( + service, + aws_access_key_id=aws_creds.get("aws_access_key_id"), + aws_secret_access_key=aws_creds.get("aws_secret_access_key"), + region_name=aws_creds.get("region_name"), + ) + + +s3_client = retrieve_service( + "s3", + { + "aws_access_key_id": settings.aws_access_key_id, + "aws_secret_access_key": settings.aws_secret_access_key, + "region_name": settings.aws_region_name, + }, +) + + +@router.post("", status_code=http_status.HTTP_202_ACCEPTED) async def inference_with_audio_url( background_tasks: BackgroundTasks, url: str, data: Optional[AudioRequest] = None, -) -> AudioResponse: +) -> dict: """Inference endpoint with audio url.""" filename = f"audio_url_{shortuuid.ShortUUID().random(length=32)}" - data = AudioRequest() if data is None else AudioRequest(**data.dict()) - async with download_limit: - _filepath = await download_audio_file("url", url, filename) + async def process_audio(): + try: + async with download_limit: + _filepath = await download_audio_file("url", url, filename) - num_channels = await check_num_channels(_filepath) - if num_channels > 1 and data.multi_channel is False: - num_channels = 1 # Force mono channel if more than 1 channel + num_channels = await check_num_channels(_filepath) + if num_channels > 1 and data.multi_channel is False: + num_channels = 1 # Force mono channel if more than 1 channel - try: - filepath: Union[str, List[str]] = await process_audio_file( - _filepath, num_channels=num_channels - ) + try: + filepath: Union[str, List[str]] = await process_audio_file( + _filepath, num_channels=num_channels + ) + + except Exception as e: + raise HTTPException( # noqa: B904 + status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Process failed: {e}", + ) + + background_tasks.add_task(delete_file, 
filepath=filename) + task = asyncio.create_task( + asr.process_input( + filepath=filepath, + url=url, + url_type="url", + offset_start=data.offset_start, + offset_end=data.offset_end, + num_speakers=data.num_speakers, + diarization=data.diarization, + multi_channel=data.multi_channel, + source_lang=data.source_lang, + timestamps_format=data.timestamps, + vocab=data.vocab, + word_timestamps=data.word_timestamps, + internal_vad=data.internal_vad, + repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, + ) + ) + + result = await task + utterances, process_times, audio_duration = result + result = AudioResponse( + utterances=utterances, + audio_duration=audio_duration, + offset_start=data.offset_start, + offset_end=data.offset_end, + num_speakers=data.num_speakers, + diarization=data.diarization, + multi_channel=data.multi_channel, + source_lang=data.source_lang, + timestamps=data.timestamps, + vocab=data.vocab, + word_timestamps=data.word_timestamps, + internal_vad=data.internal_vad, + repetition_penalty=data.repetition_penalty, + compression_ratio_threshold=data.compression_ratio_threshold, + log_prob_threshold=data.log_prob_threshold, + no_speech_threshold=data.no_speech_threshold, + condition_on_previous_text=data.condition_on_previous_text, + job_name=data.job_name, + task_token=data.task_token, + process_times=process_times, + ) + + upload_file( + s3_client, + file=bytes(json.dumps(result.model_dump()).encode("UTF-8")), + bucket=settings.aws_storage_bucket_name, + object_name=f"responses/{data.task_token}_{data.job_name}.json", + ) + + background_tasks.add_task(delete_file, filepath=filepath) + await send_update_with_svix( + data.job_name, + "finished", + { + "job_name": data.job_name, + "task_token": data.task_token, + }, + ) except Exception as e: - raise HTTPException( # noqa: B904 - status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Process failed: {e}", - ) - - background_tasks.add_task(delete_file, filepath=filename) - - task = asyncio.create_task( - asr.process_input( - filepath=filepath, - offset_start=data.offset_start, - offset_end=data.offset_end, - num_speakers=data.num_speakers, - diarization=data.diarization, - multi_channel=data.multi_channel, - source_lang=data.source_lang, - timestamps_format=data.timestamps, - vocab=data.vocab, - word_timestamps=data.word_timestamps, - internal_vad=data.internal_vad, - repetition_penalty=data.repetition_penalty, - compression_ratio_threshold=data.compression_ratio_threshold, - log_prob_threshold=data.log_prob_threshold, - no_speech_threshold=data.no_speech_threshold, - condition_on_previous_text=data.condition_on_previous_text, - ) + error_message = f"Error during transcription: {e}" + logger.error(error_message) + + error_payload = { + "error": error_message, + "job_name": data.job_name, + "task_token": data.task_token, + } + + await send_update_with_svix(data.job_name, "error", error_payload) + + # Add the process_audio function to background tasks + background_tasks.add_task(process_audio) + + # Return the job name and task token immediately + return {"job_name": data.job_name, "task_token": data.task_token} + + +def upload_file(s3_client, file, bucket, object_name): + try: + s3_client.put_object( + Body=file, + Bucket=bucket, + Key=object_name, ) - result = await task + except Exception as e: + 
logger.error(f"Exception while uploading results to S3: {e}") + return False + return True + - background_tasks.add_task(delete_file, filepath=filepath) +async def send_update_with_svix( + job_name: str, + status: str, + payload: dict, + payload_retention_period: Optional[int] = 5, +) -> None: + """ + Send the status update to Svix. - if isinstance(result, ProcessException): - logger.error(result.message) - raise HTTPException( - status_code=http_status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=str(result.message), + Args: + job_name (str): The name of the job. + status (str): The status of the job. + payload (dict): The payload to send. + payload_retention_period (Optional[int], optional): The payload retention period. Defaults to 5. + """ + if settings.svix_api_key and settings.svix_app_id: + svix = SvixAsync(settings.svix_api_key) + await svix.message.create( + settings.svix_app_id, + MessageIn( + event_type=f"async_job.wordcab_transcribe.{status}", + event_id=f"wordcab_transcribe_{status}_{job_name}_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S_%f')}", + payload_retention_period=payload_retention_period, + payload=payload, + ), ) else: - utterances, process_times, audio_duration = result - return AudioResponse( - utterances=utterances, - audio_duration=audio_duration, - offset_start=data.offset_start, - offset_end=data.offset_end, - num_speakers=data.num_speakers, - diarization=data.diarization, - multi_channel=data.multi_channel, - source_lang=data.source_lang, - timestamps=data.timestamps, - vocab=data.vocab, - word_timestamps=data.word_timestamps, - internal_vad=data.internal_vad, - repetition_penalty=data.repetition_penalty, - compression_ratio_threshold=data.compression_ratio_threshold, - log_prob_threshold=data.log_prob_threshold, - no_speech_threshold=data.no_speech_threshold, - condition_on_previous_text=data.condition_on_previous_text, - process_times=process_times, + logger.warning( + "Svix API key and app ID are not set. Cannot send the status update to" + " Svix." ) diff --git a/src/wordcab_transcribe/router/v1/cortex_endpoint.py b/src/wordcab_transcribe/router/v1/cortex_endpoint.py index 3e1064a..129fafa 100644 --- a/src/wordcab_transcribe/router/v1/cortex_endpoint.py +++ b/src/wordcab_transcribe/router/v1/cortex_endpoint.py @@ -50,7 +50,7 @@ response_model=Union[ CortexError, CortexUrlResponse, CortexYoutubeResponse, PongResponse ], - status_code=http_status.HTTP_200_OK, + status_code=http_status.HTTP_202_ACCEPTED, ) async def run_cortex( payload: CortexPayload, request: Request @@ -137,7 +137,7 @@ async def run_cortex( return CortexError(message=error_message) _cortex_response = { - **response.model_dump(), + **response, "job_name": payload.job_name, "request_id": request_id, } diff --git a/src/wordcab_transcribe/router/v1/youtube_endpoint.py b/src/wordcab_transcribe/router/v1/youtube_endpoint.py index 2a03612..ea6453e 100644 --- a/src/wordcab_transcribe/router/v1/youtube_endpoint.py +++ b/src/wordcab_transcribe/router/v1/youtube_endpoint.py @@ -52,6 +52,8 @@ async def inference_with_youtube( task = asyncio.create_task( asr.process_input( filepath=filepath, + url=url, + url_type="youtube", offset_start=data.offset_start, offset_end=data.offset_end, num_speakers=data.num_speakers, diff --git a/src/wordcab_transcribe/services/alignment/__init__.py b/src/wordcab_transcribe/services/alignment/__init__.py new file mode 100644 index 0000000..9401f40 --- /dev/null +++ b/src/wordcab_transcribe/services/alignment/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2023 The Wordcab Team. 
All rights reserved. +# +# Licensed under the Wordcab Transcribe License 0.1 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://github.com/Wordcab/wordcab-transcribe/blob/main/LICENSE +# +# Except as expressly provided otherwise herein, and to the fullest +# extent permitted by law, Licensor provides the Software (and each +# Contributor provides its Contributions) AS IS, and Licensor +# disclaims all warranties or guarantees of any kind, express or +# implied, whether arising under any law or from any usage in trade, +# or otherwise including but not limited to the implied warranties +# of merchantability, non-infringement, quiet enjoyment, fitness +# for a particular purpose, or otherwise. +# +# See the License for the specific language governing permissions +# and limitations under the License. +"""Alignment services and all Alignment related code.""" diff --git a/src/wordcab_transcribe/services/alignment/align_service.py b/src/wordcab_transcribe/services/alignment/align_service.py new file mode 100644 index 0000000..fc01445 --- /dev/null +++ b/src/wordcab_transcribe/services/alignment/align_service.py @@ -0,0 +1,664 @@ +""" +Forced Alignment with Whisper +C. Max Bain + +Inspired by: https://github.com/m-bain/whisperX +""" +from dataclasses import dataclass +from typing import Iterable, List, Union + +import numpy as np +import pandas as pd +import torch +import torchaudio +from nltk.tokenize.punkt import PunktParameters, PunktSentenceTokenizer +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + +from .models import AlignedTranscriptionResult, SingleAlignedSegment, SingleWordSegment + +PUNKT_ABBREVIATIONS = ["dr", "vs", "mr", "mrs", "prof"] + +LANGUAGES_WITHOUT_SPACES = ["ja", "zh"] + +DEFAULT_ALIGN_MODELS_TORCH = { + "en": "WAV2VEC2_ASR_BASE_960H", + "fr": "VOXPOPULI_ASR_BASE_10K_FR", + "de": "VOXPOPULI_ASR_BASE_10K_DE", + "es": "VOXPOPULI_ASR_BASE_10K_ES", + "it": "VOXPOPULI_ASR_BASE_10K_IT", +} + +DEFAULT_ALIGN_MODELS_HF = { + "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese", + "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn", + "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch", + "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm", + "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese", + "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic", + "cs": "comodoro/wav2vec2-xls-r-300m-cs-250", + "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian", + "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish", + "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian", + "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish", + "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian", + "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek", + "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish", + "da": "saattrupdan/wav2vec2-xls-r-300m-ftspeech", + "he": "imvladikon/wav2vec2-xls-r-300m-hebrew", + "vi": "nguyenvulebinh/wav2vec2-base-vi", + "ko": "kresnik/wav2vec2-large-xlsr-korean", + "ur": "kingabzpro/wav2vec2-large-xls-r-300m-Urdu", + "te": "anuragshas/wav2vec2-large-xlsr-53-telugu", + "hi": "theainerd/Wav2Vec2-large-xlsr-hindi", +} + + +def fill_missing_words(sentence_data, default_probability=0.5): + """ + Fills in missing word lists in sentence data using estimated timings and default probabilities. + + Parameters: + sentence_data (dict): A dictionary containing the sentence data. + default_probability (float, optional): The default probability to assign to each word. 
Defaults to 0.5. + + Returns: + dict: The updated sentence data with filled-in words. + """ + + def estimate_word_timings(text, start, end): + """ + Estimates start and end times for each word in the text based on the total duration and character count. + """ + words = text.split() + total_duration = end - start + total_chars = sum(len(word) for word in words) + avg_char_duration = total_duration / total_chars + + word_timings = [] + current_start = start + for word in words: + word_duration = len(word) * avg_char_duration + word_end = current_start + word_duration + word_timings.append({"start": current_start, "end": word_end, "word": word}) + current_start = word_end + + # Adjust the start of the first word and the end of the last word + if word_timings: + word_timings[0]["start"] = start + word_timings[-1]["end"] = end + + return word_timings + + # Check if word list is missing or empty + if "words" not in sentence_data or not sentence_data["words"]: + sentence_text = sentence_data["text"] + sentence_start = sentence_data["start"] + sentence_end = sentence_data["end"] + + # Estimate word timings + estimated_words = estimate_word_timings( + sentence_text, sentence_start, sentence_end + ) + + # Assign default probability to each word + for word_info in estimated_words: + word_info["probability"] = default_probability + + sentence_data["words"] = estimated_words + + return sentence_data + + +def estimate_none_timestamps(timestamp_list): + """ + Estimates missing timestamps in a list of timestamp segments based on the character length of segment times. + + Parameters: + timestamp_list (list): A list of timestamp segments with text. + + Returns: + list: The list with estimated missing timestamps. + """ + total_duration = 0 + total_characters = 0 + + for segment in timestamp_list: + start, end = segment["timestamp"] + if start is not None and end is not None: + duration = end - start + characters = len(segment["text"]) + total_duration += duration + total_characters += characters + + if total_characters > 0: + avg_duration_per_char = total_duration / total_characters + else: + avg_duration_per_char = 0.1 # Default duration per character (assumed) + + for i, segment in enumerate(timestamp_list): + start, end = segment["timestamp"] + characters = len(segment["text"]) + estimated_duration = characters * avg_duration_per_char + + if start is None: + start = ( + timestamp_list[i - 1]["timestamp"][1] + if i > 0 and timestamp_list[i - 1]["timestamp"][1] is not None + else 0 + ) + segment["timestamp"] = (start, start + estimated_duration) + if end is None: + segment["timestamp"] = (start, start + estimated_duration) + return timestamp_list + + +def load_align_model(language_code, device, model_name=None, model_dir=None): + if model_name is None: + # use default model + if language_code in DEFAULT_ALIGN_MODELS_TORCH: + model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code] + elif language_code in DEFAULT_ALIGN_MODELS_HF: + model_name = DEFAULT_ALIGN_MODELS_HF[language_code] + else: + print( + "There is no default alignment model set for this language" + f" ({language_code}). 
Please find a wav2vec2.0 model" + " finetuned on this language in https://huggingface.co/models, then" + " pass the model name in --align_model [MODEL_NAME]" + ) + raise ValueError(f"No default align-model for language: {language_code}") + + if model_name in torchaudio.pipelines.__all__: + pipeline_type = "torchaudio" + bundle = torchaudio.pipelines.__dict__[model_name] + align_model = bundle.get_model(dl_kwargs={"model_dir": model_dir}).to(device) + labels = bundle.get_labels() + align_dictionary = {c.lower(): i for i, c in enumerate(labels)} + else: + try: + processor = Wav2Vec2Processor.from_pretrained(model_name) + align_model = Wav2Vec2ForCTC.from_pretrained(model_name) + except Exception as e: + print(e) + print( + "Error loading model from huggingface, check" + " https://huggingface.co/models for finetuned wav2vec2.0 models" + ) + raise ValueError from e + pipeline_type = "huggingface" + align_model = align_model.to(device) + labels = processor.tokenizer.get_vocab() + align_dictionary = { + char.lower(): code for char, code in processor.tokenizer.get_vocab().items() + } + + align_metadata = { + "language": language_code, + "dictionary": align_dictionary, + "type": pipeline_type, + } + + return align_model, align_metadata + + +def align( + transcript: Iterable, + model: torch.nn.Module, + align_model_metadata: dict, + audio: Union[np.ndarray, torch.Tensor], + device: str, + sample_rate: int = 16000, + interpolate_method: str = "nearest", + return_char_alignments: bool = False, + print_progress: bool = False, + combined_progress: bool = False, +) -> AlignedTranscriptionResult: + """ + Align phoneme recognition predictions to known transcription. + """ + + transcript = estimate_none_timestamps(transcript) + + if not torch.is_tensor(audio): + audio = torch.from_numpy(audio) + if len(audio.shape) == 1: + audio = audio.unsqueeze(0) + + MAX_DURATION = audio.shape[1] / sample_rate + + model_dictionary = align_model_metadata["dictionary"] + model_lang = align_model_metadata["language"] + model_type = align_model_metadata["type"] + + # 1. Preprocess to keep only characters in dictionary + total_segments = len(transcript) + for sdx, segment in enumerate(transcript): + # strip spaces at beginning / end, but keep track of the amount. 
+ if print_progress: + base_progress = ((sdx + 1) / total_segments) * 100 + percent_complete = ( + (50 + base_progress / 2) if combined_progress else base_progress + ) + print(f"Progress: {percent_complete:.2f}%...") + + num_leading = len(segment["text"]) - len(segment["text"].lstrip()) + num_trailing = len(segment["text"]) - len(segment["text"].rstrip()) + text = segment["text"] + + # split into words + if model_lang not in LANGUAGES_WITHOUT_SPACES: + per_word = text.split(" ") + else: + per_word = text + + clean_char, clean_cdx = [], [] + for cdx, char in enumerate(text): + char_ = char.lower() + # wav2vec2 models use "|" character to represent spaces + if model_lang not in LANGUAGES_WITHOUT_SPACES: + char_ = char_.replace(" ", "|") + + # ignore whitespace at beginning and end of transcript + if cdx < num_leading: + pass + elif cdx > len(text) - num_trailing - 1: + pass + elif char_ in model_dictionary.keys(): + clean_char.append(char_) + clean_cdx.append(cdx) + + clean_wdx = [] + for wdx, wrd in enumerate(per_word): + if any(c in model_dictionary.keys() for c in wrd): + clean_wdx.append(wdx) + + punkt_param = PunktParameters() + punkt_param.abbrev_types = set(PUNKT_ABBREVIATIONS) + sentence_splitter = PunktSentenceTokenizer(punkt_param) + sentence_spans = list(sentence_splitter.span_tokenize(text)) + + segment["clean_char"] = clean_char + segment["clean_cdx"] = clean_cdx + segment["clean_wdx"] = clean_wdx + segment["sentence_spans"] = sentence_spans + + aligned_segments: List[SingleAlignedSegment] = [] + + # 2. Get prediction matrix from alignment model & align + + for sdx, segment in enumerate(transcript): + t1 = segment["timestamp"][0] + t2 = segment["timestamp"][1] + text = segment["text"] + + aligned_seg: SingleAlignedSegment = { + "start": t1, + "end": t2, + "text": text, + "words": [], + } + + if return_char_alignments: + aligned_seg["chars"] = [] + + # check we can align + if len(segment["clean_char"]) == 0: + print( + f'Failed to align segment ("{segment["text"]}"): no characters in this' + " segment found in model dictionary, resorting to original..." + ) + aligned_segments.append(aligned_seg) + continue + + if t1 >= MAX_DURATION: + print( + f'Failed to align segment ("{segment["text"]}"): original start time' + " longer than audio duration, skipping..." + ) + aligned_segments.append(aligned_seg) + continue + + text_clean = "".join(segment["clean_char"]) + tokens = [model_dictionary[c] for c in text_clean] + + f1 = int(t1 * sample_rate) + f2 = int(t2 * sample_rate) + + # TODO: Probably can get some speedup gain with batched inference here + waveform_segment = audio[:, f1:f2] + # Handle the minimum input length for wav2vec2 models + if waveform_segment.shape[-1] < 400: + lengths = torch.as_tensor([waveform_segment.shape[-1]]).to(device) + waveform_segment = torch.nn.functional.pad( + waveform_segment, (0, 400 - waveform_segment.shape[-1]) + ) + else: + lengths = None + + with torch.inference_mode(): + if model_type == "torchaudio": + emissions, _ = model(waveform_segment.to(device), lengths=lengths) + elif model_type == "huggingface": + emissions = model(waveform_segment.to(device)).logits + else: + raise NotImplementedError( + f"Align model of type {model_type} not supported." 
+ ) + emissions = torch.log_softmax(emissions, dim=-1) + + emission = emissions[0].cpu().detach() + + blank_id = 0 + for char, code in model_dictionary.items(): + if char == "[pad]" or char == "": + blank_id = code + + trellis = get_trellis(emission, tokens, blank_id) + path = backtrack(trellis, emission, tokens, blank_id) + + if path is None: + print( + f'Failed to align segment ("{segment["text"]}"): backtrack failed,' + " resorting to original..." + ) + aligned_segments.append(aligned_seg) + continue + + char_segments = merge_repeats(path, text_clean) + + duration = t2 - t1 + ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1) + + # assign timestamps to aligned characters + char_segments_arr = [] + word_idx = 0 + for cdx, char in enumerate(text): + start, end, score = None, None, None + if cdx in segment["clean_cdx"]: + char_seg = char_segments[segment["clean_cdx"].index(cdx)] + start = round(char_seg.start * ratio + t1, 3) + end = round(char_seg.end * ratio + t1, 3) + score = round(char_seg.score, 3) + + char_segments_arr.append( + { + "char": char, + "start": start, + "end": end, + "score": score, + "word-idx": word_idx, + } + ) + + # increment word_idx, nltk word tokenization would probably be more robust here, but us space for now... + if model_lang in LANGUAGES_WITHOUT_SPACES: + word_idx += 1 + elif cdx == len(text) - 1 or text[cdx + 1] == " ": + word_idx += 1 + + char_segments_arr = pd.DataFrame(char_segments_arr) + + aligned_subsegments = [] + # assign sentence_idx to each character index + char_segments_arr["sentence-idx"] = None + for sdx, (sstart, send) in enumerate(segment["sentence_spans"]): + curr_chars = char_segments_arr.loc[ + (char_segments_arr.index >= sstart) & (char_segments_arr.index <= send) + ] + char_segments_arr.loc[ + (char_segments_arr.index >= sstart) & (char_segments_arr.index <= send), + "sentence-idx", + ] = sdx + + sentence_text = text[sstart:send] + sentence_start = curr_chars["start"].min() + sentence_end = curr_chars["end"].max() + sentence_words = [] + avg_char_duration = None + last_end = None + + for ix, word_idx in enumerate(curr_chars["word-idx"].unique()): + word_chars = curr_chars.loc[curr_chars["word-idx"] == word_idx] + word_text = "".join(word_chars["char"].tolist()).strip() + if len(word_text) == 0: + continue + + # Don't use space character for alignment + word_chars = word_chars[word_chars["char"] != " "] + + word_start = word_chars["start"].min() + word_end = word_chars["end"].max() + word_score = round(word_chars["score"].mean(), 3) + + # -1 indicates unalignable + word_segment = {"word": word_text} + + if not np.isnan(word_start): + word_segment["start"] = word_start + if not np.isnan(word_end): + word_segment["end"] = word_end + if not np.isnan(word_score): + word_segment["score"] = word_score + + if "start" not in word_segment or "end" not in word_segment: + if avg_char_duration is None: + df = pd.DataFrame(sentence_words) + df = df.dropna( + subset=["start", "end"] + ) # Drop rows where 'start' or 'end' is NaN + if not df.empty: + df["duration"] = df["end"] - df["start"] + df["char_length"] = df["word"].apply(len) + avg_char_duration = ( + df["duration"] / df["char_length"] + ).mean() + else: + avg_char_duration = ( + 0.1 # Default average character duration + ) + + word_len = len(word_segment["word"]) + estimated_duration = word_len * avg_char_duration + + if "start" not in word_segment: + if len(sentence_words) == 0: + word_segment["start"] = sentence_start + else: + prev_end = sentence_words[len(sentence_words) - 
1]["end"] + word_segment["start"] = prev_end + + if "end" not in word_segment: + estimated_end = word_segment["start"] + estimated_duration + if ix == len(sentence_words) - 1: + word_segment["end"] = sentence_end + else: + word_segment["end"] = estimated_end + + if "score" not in word_segment: + word_segment["score"] = 0.5 + + if last_end is not None and word_segment["start"] < last_end: + word_segment["start"] = last_end + last_end = word_segment["end"] + + sentence_words.append(word_segment) + + aligned_subsegments.append( + { + "text": sentence_text, + "start": sentence_start, + "end": sentence_end, + "words": sentence_words, + } + ) + + if return_char_alignments: + curr_chars = curr_chars[["char", "start", "end", "score"]] + curr_chars.fillna(-1, inplace=True) + curr_chars = curr_chars.to_dict("records") + curr_chars = [ + {key: val for key, val in char.items() if val != -1} + for char in curr_chars + ] + aligned_subsegments[-1]["chars"] = curr_chars + + aligned_subsegments = pd.DataFrame(aligned_subsegments) + aligned_subsegments["start"] = interpolate_nans( + aligned_subsegments["start"], method=interpolate_method + ) + aligned_subsegments["end"] = interpolate_nans( + aligned_subsegments["end"], method=interpolate_method + ) + # concatenate sentences with same timestamps + agg_dict = {"text": " ".join, "words": "sum"} + if model_lang in LANGUAGES_WITHOUT_SPACES: + agg_dict["text"] = "".join + if return_char_alignments: + agg_dict["chars"] = "sum" + aligned_subsegments = aligned_subsegments.groupby( + ["start", "end"], as_index=False + ).agg(agg_dict) + aligned_subsegments = aligned_subsegments.to_dict("records") + aligned_segments += aligned_subsegments + + # create word_segments list + word_segments: List[SingleWordSegment] = [] + for segment in aligned_segments: + word_segments += segment["words"] + + return {"segments": aligned_segments, "word_segments": word_segments} + + +def interpolate_nans(x, method="nearest"): + if x.notnull().sum() > 1: + return x.interpolate(method=method).ffill().bfill() + else: + return x.ffill().bfill() + + +""" +source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html +""" + + +def get_trellis(emission, tokens, blank_id=0): + num_frame = emission.size(0) + num_tokens = len(tokens) + + # Trellis has extra diemsions for both time axis and tokens. + # The extra dim for tokens represents (start-of-sentence) + # The extra dim for time axis is for simplification of the code. + trellis = torch.empty((num_frame + 1, num_tokens + 1)) + trellis[0, 0] = 0 + trellis[1:, 0] = torch.cumsum(emission[:, 0], 0) + trellis[0, -num_tokens:] = -float("inf") + trellis[-num_tokens:, 0] = float("inf") + + for t in range(num_frame): + trellis[t + 1, 1:] = torch.maximum( + # Score for staying at the same token + trellis[t, 1:] + emission[t, blank_id], + # Score for changing to the next token + trellis[t, :-1] + emission[t, tokens], + ) + return trellis + + +@dataclass +class Point: + token_index: int + time_index: int + score: float + + +def backtrack(trellis, emission, tokens, blank_id=0): + # Note: + # j and t are indices for trellis, which has extra dimensions + # for time and tokens at the beginning. + # When referring to time frame index `T` in trellis, + # the corresponding index in emission is `T-1`. + # Similarly, when referring to token index `J` in trellis, + # the corresponding index in transcript is `J-1`. 
+ j = trellis.size(1) - 1 + t_start = torch.argmax(trellis[:, j]).item() + + path = [] + for t in range(t_start, 0, -1): + # 1. Figure out if the current position was stay or change + # Note (again): + # `emission[J-1]` is the emission at time frame `J` of trellis dimension. + # Score for token staying the same from time frame J-1 to T. + stayed = trellis[t - 1, j] + emission[t - 1, blank_id] + # Score for token changing from C-1 at T-1 to J at T. + changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]] + + # 2. Store the path with frame-wise probability. + prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item() + # Return token index and time index in non-trellis coordinate. + path.append(Point(j - 1, t - 1, prob)) + + # 3. Update the token + if changed > stayed: + j -= 1 + if j == 0: + break + else: + # failed + return None + return path[::-1] + + +# Merge the labels +@dataclass +class Segment: + label: str + start: int + end: int + score: float + + def __repr__(self): + return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})" + + @property + def length(self): + return self.end - self.start + + +def merge_repeats(path, transcript): + i1, i2 = 0, 0 + segments = [] + while i1 < len(path): + while i2 < len(path) and path[i1].token_index == path[i2].token_index: + i2 += 1 + score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1) + segments.append( + Segment( + transcript[path[i1].token_index], + path[i1].time_index, + path[i2 - 1].time_index + 1, + score, + ) + ) + i1 = i2 + return segments + + +def merge_words(segments, separator="|"): + words = [] + i1, i2 = 0, 0 + while i1 < len(segments): + if i2 >= len(segments) or segments[i2].label == separator: + if i1 != i2: + segs = segments[i1:i2] + word = "".join([seg.label for seg in segs]) + score = sum(seg.score * seg.length for seg in segs) / sum( + seg.length for seg in segs + ) + words.append( + Segment(word, segments[i1].start, segments[i2 - 1].end, score) + ) + i1 = i2 + 1 + i2 = i1 + else: + i2 += 1 + return words diff --git a/src/wordcab_transcribe/services/alignment/models.py b/src/wordcab_transcribe/services/alignment/models.py new file mode 100644 index 0000000..90b897c --- /dev/null +++ b/src/wordcab_transcribe/services/alignment/models.py @@ -0,0 +1,63 @@ +from typing import List, Optional, TypedDict + + +class SingleWordSegment(TypedDict): + """ + A single word of a speech. + """ + + word: str + start: float + end: float + score: float + + +class SingleCharSegment(TypedDict): + """ + A single char of a speech. + """ + + char: str + start: float + end: float + score: float + + +class SingleSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech. + """ + + start: float + end: float + text: str + + +class SingleAlignedSegment(TypedDict): + """ + A single segment (up to multiple sentences) of a speech with word alignment. + """ + + start: float + end: float + text: str + words: List[SingleWordSegment] + chars: Optional[List[SingleCharSegment]] + + +class TranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. + """ + + segments: List[SingleSegment] + language: str + + +class AlignedTranscriptionResult(TypedDict): + """ + A list of segments and word segments of a speech. 
+ """ + + segments: List[SingleAlignedSegment] + word_segments: List[SingleWordSegment] diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py index 3e858a2..e062375 100644 --- a/src/wordcab_transcribe/services/asr_service.py +++ b/src/wordcab_transcribe/services/asr_service.py @@ -26,7 +26,7 @@ from dataclasses import dataclass from enum import Enum from pathlib import Path -from typing import Iterable, List, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import aiohttp import torch @@ -35,6 +35,7 @@ from tensorshare import Backend, TensorShare from typing_extensions import Literal +from wordcab_transcribe.config import settings from wordcab_transcribe.logging import time_and_tell, time_and_tell_async from wordcab_transcribe.models import ( DiarizationOutput, @@ -90,6 +91,8 @@ class ASRTask(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) audio: Union[torch.Tensor, List[torch.Tensor]] + url: Union[str, None] + url_type: Union[str, None] diarization: "DiarizationTask" duration: float multi_channel: bool @@ -209,7 +212,8 @@ def __init__(self) -> None: @abstractmethod async def process_input(self) -> None: - """Process the input request by creating a task and adding it to the appropriate queues.""" + """Process the input request by creating a task and adding it to the appropriate queues. + """ raise NotImplementedError("This method should be implemented in subclasses.") @@ -264,9 +268,9 @@ def __init__( self.shift_lengths: List[float] = shift_lengths self.multiscale_weights: List[float] = multiscale_weights self.extra_languages: Union[List[str], None] = extra_languages - self.extra_languages_model_paths: Union[List[str], None] = ( - extra_languages_model_paths - ) + self.extra_languages_model_paths: Union[ + List[str], None + ] = extra_languages_model_paths self.local_services: LocalServiceRegistry = LocalServiceRegistry() self.remote_services: RemoteServiceRegistry = RemoteServiceRegistry() @@ -388,6 +392,8 @@ async def process_input( # noqa: C901 log_prob_threshold: float, no_speech_threshold: float, condition_on_previous_text: bool, + url: Optional[str] = None, + url_type: Optional[str] = None, ) -> Union[Tuple[List[dict], ProcessTimes, float], Exception]: """Process the input request and return the results. 
@@ -481,6 +487,8 @@ async def process_input( # noqa: C901 task = ASRTask( audio=audio, + url=url, + url_type=url_type, diarization=DiarizationTask( execution=diarization_execution, num_speakers=num_speakers ), @@ -644,10 +652,18 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None: result, process_time = out elif isinstance(task.diarization.execution, RemoteExecution): - ts = TensorShare.from_dict({"audio": task.audio}, backend=Backend.TORCH) + if task.url: + audio = task.url + audio_type = task.url_type + else: + audio = TensorShare.from_dict( + {"audio": task.audio}, backend=Backend.TORCH + ) + audio_type = "tensor" data = DiarizationRequest( - audio=ts, + audio=audio, + audio_type=audio_type, duration=task.duration, num_speakers=task.diarization.num_speakers, ) @@ -776,11 +792,31 @@ async def remote_diarization( data: DiarizationRequest, ) -> DiarizationOutput: """Remote diarization method.""" - async with aiohttp.ClientSession() as session: + headers = {"Content-Type": "application/json"} + + if not settings.debug: + headers = {"Content-Type": "application/x-www-form-urlencoded"} + auth_url = f"{url}/api/v1/auth" + async with aiohttp.ClientSession() as session: + async with session.post( + url=auth_url, + data={"username": settings.username, "password": settings.password}, + headers=headers, + ) as response: + if response.status != 200: + raise Exception(response.status) + else: + token = await response.json() + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {token['access_token']}", + } + diarization_timeout = aiohttp.ClientTimeout(total=1200) + async with aiohttp.ClientSession(timeout=diarization_timeout) as session: async with session.post( url=f"{url}/api/v1/diarize", data=data.model_dump_json(), - headers={"Content-Type": "application/json"}, + headers=headers, ) as response: if response.status != 200: r = await response.json() @@ -1044,6 +1080,7 @@ async def inference_warmup(self) -> None: data = DiarizationRequest( audio=ts, + audio_type="tensor", duration=duration, num_speakers=1, ) @@ -1067,13 +1104,28 @@ async def process_input(self, data: DiarizationRequest) -> DiarizationOutput: gpu_index = await self.gpu_handler.get_device() try: - result = self.diarization_service( - waveform=data.audio, - audio_duration=data.duration, - oracle_num_speakers=data.num_speakers, - model_index=gpu_index, - vad_service=self.vad_service, - ) + if data.audio_type == "tensor": + result = self.diarization_service( + waveform=data.audio, + audio_duration=data.duration, + oracle_num_speakers=data.num_speakers, + model_index=gpu_index, + vad_service=self.vad_service, + ) + elif data.audio_type and data.audio_type in ["youtube", "url"]: + result = self.diarization_service( + url=data.audio, + url_type=data.audio_type, + audio_duration=data.duration, + oracle_num_speakers=data.num_speakers, + model_index=gpu_index, + vad_service=self.vad_service, + ) + else: + raise ValueError( + f"Invalid audio type: {data.audio_type}. " + "Must be one of ['tensor', 'youtube', 'url']." 
+ ) except Exception as e: result = ProcessException( diff --git a/src/wordcab_transcribe/services/diarization/diarize_service.py b/src/wordcab_transcribe/services/diarization/diarize_service.py index 27f5c90..29a04c1 100644 --- a/src/wordcab_transcribe/services/diarization/diarize_service.py +++ b/src/wordcab_transcribe/services/diarization/diarize_service.py @@ -18,10 +18,11 @@ # See the License for the specific language governing permissions # and limitations under the License. """Diarization Service for audio files.""" - -from typing import List, NamedTuple, Tuple, Union +import os +from typing import List, NamedTuple, Optional, Tuple, Union import torch +from loguru import logger from tensorshare import Backend, TensorShare from wordcab_transcribe.models import DiarizationOutput @@ -33,6 +34,12 @@ SegmentationModule, ) from wordcab_transcribe.services.vad_service import VadService +from wordcab_transcribe.utils import ( + delete_file, + download_audio_file_sync, + process_audio_file_sync, + read_audio, +) class DiarizationModels(NamedTuple): @@ -75,13 +82,17 @@ def __init__( self.default_shift_lengths = shift_lengths self.default_multiscale_weights = multiscale_weights - if len(self.default_multiscale_weights) > 3: - self.default_segmentation_batch_size = 64 - elif len(self.default_multiscale_weights) > 1: - self.default_segmentation_batch_size = 128 + self.seg_batch_size = os.getenv("DIARIZATION_SEGMENTATION_BATCH_SIZE", None) + if self.seg_batch_size is not None: + self.default_segmentation_batch_size = int(self.seg_batch_size) else: - self.default_segmentation_batch_size = 256 - + if len(self.default_multiscale_weights) > 3: + self.default_segmentation_batch_size = 64 + elif len(self.default_multiscale_weights) > 1: + self.default_segmentation_batch_size = 128 + else: + self.default_segmentation_batch_size = 256 + logger.info(f"segmentation_batch_size set to {self.seg_batch_size}") self.default_scale_dict = dict(enumerate(zip(window_lengths, shift_lengths))) for idx in device_index: @@ -98,11 +109,13 @@ def __init__( def __call__( self, - waveform: Union[torch.Tensor, TensorShare], audio_duration: float, oracle_num_speakers: int, model_index: int, vad_service: VadService, + waveform: Optional[Union[torch.Tensor, TensorShare]] = None, + url: Optional[str] = None, + url_type: Optional[str] = None, ) -> DiarizationOutput: """ Run inference with the diarization model. @@ -123,9 +136,21 @@ def __call__( DiarizationOutput: List of segments with the following keys: "start", "end", "speaker". 
""" - if isinstance(waveform, TensorShare): + if url and url_type: + import shortuuid + + filename = f"audio_{shortuuid.ShortUUID().random(length=32)}" + filepath = download_audio_file_sync(url_type, url, filename) + filepath = process_audio_file_sync(filepath) + waveform, _ = read_audio(filepath) + delete_file(filepath) + elif isinstance(waveform, TensorShare): ts = waveform.to_tensors(backend=Backend.TORCH) waveform = ts["audio"] + elif isinstance(waveform, torch.Tensor): + pass + else: + return None vad_outputs, _ = vad_service(waveform, group_timestamps=False) @@ -145,11 +170,17 @@ def __call__( ) ) ) - segmentation_batch_size = 64 + if self.seg_batch_size: + segmentation_batch_size = int(self.seg_batch_size) + else: + segmentation_batch_size = 64 multiscale_weights = self.default_multiscale_weights else: scale_dict = dict(enumerate(zip([3.0, 2.0, 1.0], [0.75, 0.5, 0.25]))) - segmentation_batch_size = 32 + if self.seg_batch_size: + segmentation_batch_size = int(self.seg_batch_size) + else: + segmentation_batch_size = 32 multiscale_weights = [1.0, 1.0, 1.0] ms_emb_ts: MultiscaleEmbeddingsAndTimestamps = self.models[ diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py index 4a0fa9b..1aa4983 100644 --- a/src/wordcab_transcribe/services/transcribe_service.py +++ b/src/wordcab_transcribe/services/transcribe_service.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions # and limitations under the License. """Transcribe Service for audio files.""" - +import os from typing import Iterable, List, NamedTuple, Optional, Union import torch @@ -29,9 +29,15 @@ from wordcab_transcribe.models import ( MultiChannelSegment, MultiChannelTranscriptionOutput, + Segment, TranscriptionOutput, Word, ) +from wordcab_transcribe.services.alignment.align_service import ( + align, + estimate_none_timestamps, + load_align_model, +) class FasterWhisperModel(NamedTuple): @@ -75,12 +81,61 @@ def __init__( self.compute_type = compute_type self.model_path = model_path - self.model = WhisperModel( - self.model_path, - device=self.device, - device_index=device_index, - compute_type=self.compute_type, - ) + whisper_engine = os.getenv("WHISPER_ENGINE", "faster-whisper") + if whisper_engine == "faster-whisper": + self.model = WhisperModel( + self.model_path, + device=self.device, + device_index=device_index, + compute_type=self.compute_type, + ) + else: + from transformers import ( + AutoModelForCausalLM, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + pipeline, + ) + + model_id = os.getenv("WHISPER_TEACHER_MODEL", "openai/whisper-medium.en") + logger.info(f"WHISPER_TEACHER_MODEL set to {model_id}") + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=False, + use_safetensors=False, + use_flash_attention_2=False, + ) + model.to(device) + + self.processor = AutoProcessor.from_pretrained(model_id) + + assistant_model_id = os.getenv( + "DISTIL_WHISPER_ASSISTANT_MODEL", "distil-whisper/distil-medium.en" + ) + logger.info(f"DISTIL_WHISPER_ASSISTANT_MODEL set to {assistant_model_id}") + assistant_model = AutoModelForCausalLM.from_pretrained( + assistant_model_id, + torch_dtype=torch.float16, + low_cpu_mem_usage=False, + use_safetensors=False, + use_flash_attention_2=False, + ) + assistant_model.to(device) + + self.model = pipeline( + "automatic-speech-recognition", + model=model, + tokenizer=self.processor.tokenizer, + 
feature_extractor=self.processor.feature_extractor, + max_new_tokens=128, + chunk_length_s=30, + torch_dtype=torch.float16, + generate_kwargs={"assistant_model": assistant_model}, + device="cuda", + ) + self.align_model, self.align_model_metadata = load_align_model("en", "cuda") + self.align = align self.extra_lang = extra_languages self.extra_lang_models = extra_languages_model_paths @@ -191,32 +246,8 @@ def __call__( ts = audio.to_tensors(backend=Backend.NUMPY) audio = ts["audio"] - segments, _ = self.model.transcribe( - audio, - language=source_lang, - initial_prompt=prompt, - repetition_penalty=repetition_penalty, - compression_ratio_threshold=compression_ratio_threshold, - log_prob_threshold=log_prob_threshold, - no_speech_threshold=no_speech_threshold, - condition_on_previous_text=condition_on_previous_text, - suppress_blank=suppress_blank, - word_timestamps=word_timestamps, - vad_filter=internal_vad, - vad_parameters={ - "threshold": 0.5, - "min_speech_duration_ms": 250, - "min_silence_duration_ms": 100, - "speech_pad_ms": 30, - "window_size_samples": 512, - }, - ) - - segments = list(segments) - if not segments: - logger.warning( - "Empty transcription result. Trying with vad_filter=True." - ) + whisper_engine = os.getenv("WHISPER_ENGINE", "faster-whisper") + if whisper_engine == "faster-whisper": segments, _ = self.model.transcribe( audio, language=source_lang, @@ -226,10 +257,66 @@ def __call__( log_prob_threshold=log_prob_threshold, no_speech_threshold=no_speech_threshold, condition_on_previous_text=condition_on_previous_text, - suppress_blank=False, - word_timestamps=True, - vad_filter=False if internal_vad else True, + suppress_blank=suppress_blank, + word_timestamps=word_timestamps, + vad_filter=internal_vad, + vad_parameters={ + "threshold": 0.5, + "min_speech_duration_ms": 250, + "min_silence_duration_ms": 100, + "speech_pad_ms": 30, + "window_size_samples": 512, + }, + ) + else: + segments = [] + batch_size = os.getenv("WHISPER_BATCH_SIZE", 8) + logger.info(f"WHISPER_BATCH_SIZE set to {batch_size}") + outputs = self.model( + audio, return_timestamps=True, batch_size=int(batch_size) ) + for output in outputs["chunks"]: + output["text"] = output["text"].strip() + segments.append(output) + + segments = estimate_none_timestamps(segments) + + # segments = self.align( + # transcript=segments, + # align_model_metadata=self.align_model_metadata, + # model=self.align_model, + # audio=audio, + # device="cuda", + # )["segments"] + + for ix, segment in enumerate(segments): + # for _ix, word in enumerate(segment["words"]): + # word = { + # "start": word.pop("start"), + # "end": word.pop("end"), + # "word": word.pop("word"), + # "probability": word.pop("score") + # } + # segment["words"][_ix] = word + # if not segment["words"]: + # segment = fill_missing_words(segment) + # segment["start"] = segment["words"][0]["start"] + # segment["end"] = segment["words"][-1]["end"] + # segment["text"] = " ".join([word["word"].strip() for word in segment["words"]]).strip() + extra = { + "seek": 1, + "id": 1, + "tokens": [1], + "temperature": 0.0, + "avg_logprob": 0.0, + "compression_ratio": 0.0, + "no_speech_prob": 0.0, + } + segments[ix]["start"] = segment["timestamp"][0] + segments[ix]["end"] = segment["timestamp"][1] + segments[ix].pop("timestamp") + segments[ix]["words"] = [] + segments[ix] = Segment(**{**segment, **extra}) _outputs = [segment._asdict() for segment in segments] outputs = TranscriptionOutput(segments=_outputs) diff --git a/src/wordcab_transcribe/utils.py 
b/src/wordcab_transcribe/utils.py index fcb752a..64b02e0 100644 --- a/src/wordcab_transcribe/utils.py +++ b/src/wordcab_transcribe/utils.py @@ -29,6 +29,7 @@ import aiofiles import aiohttp import huggingface_hub +import requests import soundfile as sf import torch import torchaudio @@ -221,6 +222,36 @@ async def download_audio_file( return filename +# pragma: no cover +def download_audio_file_sync( + source: str, + url: str, + filename: str, +) -> Union[str, Awaitable[str]]: + """ + Download an audio file from a URL. + + Args: + source (str): Source of the audio file. Can be "youtube" or "url". + url (str): URL of the audio file. + filename (str): Filename to save the file as. + + Raises: + ValueError: If the source is invalid. Valid sources are: youtube, url. + + Returns: + Union[str, Awaitable[str]]: Path to the downloaded file. + """ + if source == "youtube": + filename = _download_file_from_youtube(url, filename) + elif source == "url": + filename = _download_file_from_url_sync(url, filename) + else: + raise ValueError(f"Invalid source: {source}. Valid sources are: youtube, url.") + + return filename + + # pragma: no cover def _download_file_from_youtube(url: str, filename: str) -> str: """ @@ -233,6 +264,7 @@ def _download_file_from_youtube(url: str, filename: str) -> str: Returns: str: Path to the downloaded file. """ + logger.info(f"Downloading YouTube file from {url} to {filename}...") with YoutubeDL( { "format": "bestaudio", @@ -286,6 +318,40 @@ async def _download_file_from_url( return filename +def _download_file_from_url_sync( + url: str, filename: str, url_headers: Optional[Dict[str, str]] = None +) -> str: + """ + Download a file from a URL using requests. + + Args: + url (str): URL of the audio file. + filename (str): Filename to save the file as. + url_headers (Optional[Dict[str, str]]): Headers to send with the request. Defaults to None. + + Returns: + str: Path to the downloaded file. + + Raises: + Exception: If the file failed to download. + """ + url_headers = url_headers or {} + + logger.info(f"Downloading audio file from {url} to {filename}...") + + response = requests.get(url, headers=url_headers, stream=True) + + if response.status_code == 200: + with open(filename, "wb") as f: + for chunk in response.iter_content(chunk_size=1024): + if chunk: # filter out keep-alive new chunks + f.write(chunk) + else: + raise Exception(f"Failed to download file. Status: {response.status_code}") + + return filename + + # pragma: no cover def download_model(compute_type: str, language: str) -> Optional[str]: """ @@ -524,6 +590,52 @@ async def process_audio_file( return output_files +def process_audio_file_sync(filepath: str) -> Union[str, List[str]]: + """Prepare the audio for inference. + + Process an audio file using ffmpeg. The file will be converted to WAV. + The codec used is pcm_s16le and the sample rate is 16000. + + Args: + filepath (str): + Path to the file to process. + + Raises: + FileNotFoundError: If the file does not exist. + Exception: If there's an error in processing. + + Returns: + Union[str, List[str]]: Path to the converted/split files. 
+ """ + _filepath = Path(filepath) + + if not _filepath.exists(): + raise FileNotFoundError(f"File {filepath} does not exist.") + + new_filepath = f"{_filepath.stem}_{_filepath.stat().st_mtime_ns}.wav" + cmd = [ + "ffmpeg", + "-i", + filepath, + "-vn", + "-acodec", + "pcm_s16le", + "-ar", + "16000", + "-ac", + "1", + "-y", + new_filepath, + ] + + result = run_subprocess(cmd) + + if result[0] != 0: + raise Exception(f"Error converting file {filepath} to wav format: {result[2]}") + + return new_filepath + + def read_audio( audio: Union[str, bytes], offset_start: Union[float, None] = None,