From 7900a2f3d19922703e28a0476722c6a538e9faf7 Mon Sep 17 00:00:00 2001
From: aleks
Date: Fri, 31 Mar 2023 08:03:08 -0400
Subject: [PATCH] added source_lang setting, option to change timestamp format

---
 .dockerignore                 |   1 +
 Dockerfile                    |   3 +-
 requirements.txt              |   1 +
 wordcab_transcribe/main.py    |  43 +++++++----
 wordcab_transcribe/service.py | 141 ++++++++++++++++++++++++----------
 5 files changed, 133 insertions(+), 56 deletions(-)
 create mode 100644 .dockerignore

diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..d6d95cf
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1 @@
+models
\ No newline at end of file
diff --git a/Dockerfile b/Dockerfile
index 0434322..77bed91 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,6 +1,5 @@
 FROM nvidia/cuda:11.7.0-devel-ubuntu22.04
 
-COPY requirements.txt /requirements.txt
 RUN apt-get update && apt-get install -y \
     git \
     curl \
@@ -14,6 +13,8 @@ RUN add-apt-repository ppa:deadsnakes/ppa \
+COPY requirements.txt /requirements.txt
+
 RUN python3.10 -m pip install -r requirements.txt
 RUN python3.10 -m pip install --upgrade torch==1.13.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
 
 COPY . /app
 
 WORKDIR /app
diff --git a/requirements.txt b/requirements.txt
index 28db089..2e27a00 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,5 +8,6 @@ pydantic>=1.10.7
 python-dotenv>=1.0.0
 python-multipart>=0.0.6
 scikit-learn>=1.2.2
+shortuuid>=1.0.0
 uvicorn>=0.21.1
 yt-dlp>=2023.3.4
diff --git a/wordcab_transcribe/main.py b/wordcab_transcribe/main.py
index ec84658..071322a 100644
--- a/wordcab_transcribe/main.py
+++ b/wordcab_transcribe/main.py
@@ -11,16 +11,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 """Main API module of the Wordcab Transcribe."""
 
-import aiofiles
-import asyncio
 import random
+import asyncio
+
+import aiofiles
+import shortuuid
 from loguru import logger
+from typing import Optional
 
-from fastapi import BackgroundTasks, FastAPI, File, UploadFile
 from fastapi import status as http_status
 from fastapi.responses import HTMLResponse
+from fastapi import BackgroundTasks, FastAPI, File, UploadFile
 
 from wordcab_transcribe.config import settings
 from wordcab_transcribe.models import ASRResponse
@@ -28,6 +32,7 @@ from wordcab_transcribe.utils import convert_file_to_wav, delete_file, download_file_from_youtube
 
+
 app = FastAPI(
     title=settings.project_name,
     version=settings.version,
@@ -79,7 +84,9 @@ async def health_check():
 async def inference_with_audio(
     background_tasks: BackgroundTasks,
     file: UploadFile = File(...),
-    num_speakers: int | None = None,
+    num_speakers: Optional[int] = 0,
+    source_lang: Optional[str] = "en",
+    timestamps: Optional[str] = "seconds",
 ):
     """
     Inference endpoint.
 
     Args:
         background_tasks (BackgroundTasks): Background tasks dependency.
         file (UploadFile): Audio file.
-        num_speakers (int): Number of speakers in the audio file. Default: 0.
+        num_speakers (int): Number of speakers to detect; defaults to 0, which
+            attempts to detect the number of speakers automatically.
+        source_lang (str): The language of the source file; defaults to "en".
+        timestamps (str): The format of the transcript timestamps. Options
+            are "seconds", "milliseconds", or "hms" (hours, minutes, seconds).
+            Defaults to "seconds".
 
     Returns:
         ASRResponse: Response data.
@@ -101,10 +113,9 @@ async def inference_with_audio( response = requests.post("url/api/v1/audio", files=files) print(response.json()) """ - num_speakers = num_speakers or 0 extension = file.filename.split(".")[-1] + filename = f"audio_{shortuuid.ShortUUID().random(length=32)}.{extension}" - filename = f"audio_{''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=32))}.{extension}" async with aiofiles.open(filename, "wb") as f: audio_bytes = await file.read() await f.write(audio_bytes) @@ -115,12 +126,12 @@ async def inference_with_audio( else: filepath = filename - utterances = await asr.process_input(filepath, num_speakers) + utterances = await asr.process_input(filepath, num_speakers, source_lang, timestamps) utterances = [ { - "start": float(utterance["start"]), "text": str(utterance["text"]), - "end": float(utterance["end"]), + "start": utterance["start"], + "end": utterance["end"], "speaker": int(utterance["speaker"]), } for utterance in utterances @@ -140,7 +151,9 @@ async def inference_with_audio( async def inference_with_youtube( background_tasks: BackgroundTasks, url: str, - num_speakers: int | None = None, + num_speakers: Optional[int] = 0, + source_lang: Optional[str] = "en", + timestamps: Optional[str] = "seconds", ): """ Inference endpoint. @@ -161,15 +174,15 @@ async def inference_with_youtube( """ num_speakers = num_speakers or 0 - filename = f"audio_{''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=32))}" + filename = f"yt_{shortuuid.ShortUUID().random(length=32)}" filepath = await download_file_from_youtube(url, filename) - utterances = await asr.process_input(filepath, num_speakers) + utterances = await asr.process_input(filepath, num_speakers, source_lang, timestamps) utterances = [ { - "start": float(utterance["start"]), "text": str(utterance["text"]), - "end": float(utterance["end"]), + "start": utterance["start"], + "end": utterance["end"], "speaker": int(utterance["speaker"]), } for utterance in utterances diff --git a/wordcab_transcribe/service.py b/wordcab_transcribe/service.py index e87d470..d31741a 100644 --- a/wordcab_transcribe/service.py +++ b/wordcab_transcribe/service.py @@ -13,54 +13,79 @@ # limitations under the License. 
"""Service module to handle AI model interactions.""" +import io +import math import asyncio import functools -import io +from pathlib import Path + import numpy as np from loguru import logger -from typing import List +from typing import List, Optional + +from wordcab_transcribe.config import settings +from wordcab_transcribe.utils import format_segments import torch -from faster_whisper import WhisperModel +from sklearn.metrics import silhouette_score +from sklearn.cluster import AgglomerativeClustering from pyannote.audio import Audio -from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding from pyannote.core import Segment +from pyannote.audio.pipelines.speaker_verification import PretrainedSpeakerEmbedding -from sklearn.metrics import silhouette_score -from sklearn.cluster import AgglomerativeClustering +from faster_whisper import WhisperModel -from wordcab_transcribe.config import settings -from wordcab_transcribe.utils import format_segments -class ASRService(): +class ASRService: def __init__( self, model_size: str = "large-v2", - embds_model: str = "speechbrain/spkrec-ecapa-voxceleb", + whisper_dir: str = "models/whisper_model", + compute_type: str = "int8_float16", + embeddings_model: str = "speechbrain/spkrec-ecapa-voxceleb", ) -> None: """ ASR Service class to handle AI model interactions. Args: model_size (str, optional): Model size to use. Defaults to "large-v2". - embds_model (str, optional): Speaker embeddings model to use. + whisper_dir (str, optional): If mounting the Whisper model, this directory will be used to load the model. + embeddings_model (str, optional): Speaker embeddings model to use. Defaults to "speechbrain/spkrec-ecapa-voxceleb". """ self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model_size = model_size - self.embds_model = embds_model + self.whisper_dir = whisper_dir + self.compute_type = compute_type + self.embeddings_model = embeddings_model + + if Path(self.whisper_dir).exists(): + try: + self.model = WhisperModel( + self.whisper_dir, + device=self.device, + compute_type=self.compute_type, + model_dir=whisper_dir + ) + except: + logger.error("Failed to load Whisper model from directory. Downloading model...") + self.model = WhisperModel( + self.model_size, + device=self.device, + compute_type=self.compute_type + ) + else: + self.model = WhisperModel( + self.model_size, + device=self.device, + compute_type=self.compute_type + ) - self.model = WhisperModel( - self.model_size, - device=self.device, - compute_type="int8_float16" - ) self.embedding_model = PretrainedSpeakerEmbedding( - self.embds_model, + self.embeddings_model, device=self.device ) @@ -84,7 +109,7 @@ def schedule_processing_if_needed(self) -> None: ) - async def process_input(self, filepath: str, num_speakers: int) -> List[dict]: + async def process_input(self, filepath: str, num_speakers: int, source_lang: str, timestamps: str) -> List[dict]: """ Process the input request and return the result. 
@@ -99,6 +124,8 @@ async def process_input(self, filepath: str, num_speakers: int) -> List[dict]:
             "done_event": asyncio.Event(),
             "input": filepath,
             "num_speakers": num_speakers,
+            "source_lang": source_lang,
+            "timestamps": timestamps,
             "time": asyncio.get_event_loop().time(),
         }
         async with self.queue_lock:
@@ -132,11 +159,14 @@ async def runner(self) -> None:
             self.schedule_processing_if_needed()
 
             try:
-                batch = [(task["input"], task["num_speakers"]) for task in file_batch]
+                batch = [
+                    (task["input"], task["num_speakers"], task["source_lang"], task["timestamps"])
+                    for task in file_batch
+                ]
                 results = []
-                for input_file, num_speakers in batch:
+                for input_file, num_speakers, source_lang, timestamps in batch:
                     res = await asyncio.get_event_loop().run_in_executor(
-                        None, functools.partial(self.inference, input_file, num_speakers)
+                        None, functools.partial(self.inference, input_file, num_speakers, source_lang, timestamps)
                     )
                     results.append(res)
                 for task, result in zip(file_batch, results):
@@ -152,38 +182,57 @@ async def runner(self) -> None:
 
                 task["done_event"].set()
 
-    def inference(self, filepath: str, num_speakers: int) -> List[dict]:
+    def convert_seconds_to_hms(self, seconds: float) -> str:
+        hours, remainder = divmod(seconds, 3600)
+        minutes, seconds = divmod(remainder, 60)
+        milliseconds = math.floor((seconds % 1) * 1000)
+        output = f"{int(hours):02}:{int(minutes):02}:{int(seconds):02},{milliseconds:03}"
+        return output
+
+
+    def inference(
+        self,
+        filepath: str,
+        num_speakers: int,
+        source_lang: str,
+        timestamps: str,
+    ) -> List[dict]:
         """
         Inference method to process the audio file.
 
         Args:
             filepath (str): Path to the audio file.
-            num_speakers (int): Number of speakers to detect.
+            num_speakers (int): Number of speakers to detect (0 to auto-detect).
         Returns:
             List[dict]: List of diarized segments.
         """
+
-        segments, _ = self.model.transcribe(filepath, language="en", beam_size=5, word_timestamps=True)
+        segments, _ = self.model.transcribe(filepath, language=source_lang, beam_size=5, word_timestamps=True)
         segments = format_segments(list(segments))
 
         duration = segments[-1]["end"]
-
-        diarized_segments = self.diarize(filepath, segments, duration, num_speakers)
+        diarized_segments = self.diarize(filepath, segments, duration, num_speakers, timestamps)
 
         return diarized_segments
 
     def diarize(
-        self, audio_obj: io.BytesIO, segments: List[dict], duration: float, num_speakers: int = None
+        self, audio_obj: str,
+        segments: List[dict],
+        duration: float,
+        num_speakers: int,
+        timestamps: str,
     ) -> List[dict]:
         """
         Diarize the segments using pyannote.
 
         Args:
-            audio_obj (io.BytesIO): Audio file object.
+            audio_obj (str): Path to the audio file.
             segments (List[dict]): List of segments to diarize.
             duration (float): Duration of the audio file.
-            num_speakers (int, optional): Number of speakers. Defaults to None.
+            num_speakers (int): Number of speakers (0 to auto-detect).
+            timestamps (str): Timestamp format: "seconds", "milliseconds", or "hms".
 
         Returns:
             List[dict]: List of diarized segments with speaker labels.
@@ -199,8 +248,7 @@ def diarize( best_num_speakers = self._get_num_speakers(embeddings, num_speakers) identified_segments = self._assign_speaker_label(segments, embeddings, best_num_speakers) - - joined_segments = self.join_utterances(identified_segments) + joined_segments = self.join_utterances(identified_segments, timestamps) return joined_segments @@ -228,7 +276,7 @@ def segment_embedding(self, audio_obj: io.BytesIO, segment: dict, duration: floa return self.embedding_model(waveform[None]) - def join_utterances(self, segments: List[dict]) -> List[dict]: + def join_utterances(self, segments: List[dict], timestamps: str) -> List[dict]: """ Join the segments of the same speaker. @@ -262,6 +310,17 @@ def join_utterances(self, segments: List[dict]) -> List[dict]: current_utterance["text"] = text.strip() utterance_list.append(current_utterance) + for utterance in utterance_list: + if timestamps == "hms": + utterance["start"] = self.convert_seconds_to_hms(utterance["start"]) + utterance["end"] = self.convert_seconds_to_hms(utterance["end"]) + elif timestamps == "seconds": + utterance["start"] = float(utterance["start"]) + utterance["end"] = float(utterance["end"]) + elif timestamps == "milliseconds": + utterance["start"] = float(utterance["start"] * 1000) + utterance["end"] = float(utterance["end"] * 1000) + return utterance_list @@ -285,7 +344,6 @@ def _get_num_speakers(self, embeddings: np.ndarray, num_speakers: int) -> int: score_num_speakers[i] = score best_num_speakers = max(score_num_speakers, key=lambda x: score_num_speakers[x]) - else: best_num_speakers = num_speakers @@ -306,10 +364,13 @@ def _assign_speaker_label( Returns: List[int]: List of segments with speaker labels. """ - clustering = AgglomerativeClustering(best_num_speakers).fit(embeddings) - labels = clustering.labels_ - - for i in range(len(segments)): - segments[i]["speaker"] = labels[i] + 1 + if best_num_speakers == 1: + for i in range(len(segments)): + segments[i]["speaker"] = 1 + else: + clustering = AgglomerativeClustering(best_num_speakers).fit(embeddings) + labels = clustering.labels_ + for i in range(len(segments)): + segments[i]["speaker"] = labels[i] + 1 return segments
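The new query parameters can be exercised end to end with a plain HTTP call against the /api/v1/audio route shown in the endpoint docstring. The snippet below is a sketch of such a call; the base URL, port, and audio file name are placeholders for a local deployment, not values defined by this patch.

import requests

# Query parameters added by this change (all optional, with the defaults shown
# in the endpoint signature).
params = {
    "num_speakers": 0,      # 0 lets the service estimate the number of speakers
    "source_lang": "fr",    # passed through to the Whisper transcription call
    "timestamps": "hms",    # "seconds", "milliseconds", or "hms"
}

with open("sample.wav", "rb") as f:  # placeholder audio file
    response = requests.post(
        "http://localhost:5001/api/v1/audio",  # placeholder base URL and port
        files={"file": f},
        params=params,
    )
print(response.json())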
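For reference, the three timestamp formats applied at the end of join_utterances() all describe the same utterance boundaries. The standalone sketch below mirrors the convert_seconds_to_hms() logic for illustration; it is not code taken from the repository.

import math

def to_hms(seconds: float) -> str:
    # Same arithmetic as convert_seconds_to_hms(): split into h/m/s, then
    # take the fractional second as SRT-style milliseconds.
    hours, remainder = divmod(seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    milliseconds = math.floor((secs % 1) * 1000)
    return f"{int(hours):02}:{int(minutes):02}:{int(secs):02},{milliseconds:03}"

# For an utterance boundary at 3725.5 seconds:
assert to_hms(3725.5) == "01:02:05,500"   # timestamps="hms"
assert float(3725.5) == 3725.5            # timestamps="seconds"
assert float(3725.5 * 1000) == 3725500.0  # timestamps="milliseconds"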