diff --git a/src/wordcab_transcribe/logging.py b/src/wordcab_transcribe/logging.py
index d5d306c..e1d0fcb 100644
--- a/src/wordcab_transcribe/logging.py
+++ b/src/wordcab_transcribe/logging.py
@@ -95,7 +95,7 @@ def time_and_tell(
         The appropriate wrapper for the function.
     """
     start_time = time.time()
-    result = func()
+    result = func
     process_time = time.time() - start_time
 
     if debug_mode:
@@ -121,11 +121,13 @@ async def time_and_tell_async(
     start_time = time.time()
 
     if asyncio.iscoroutinefunction(func) or asyncio.iscoroutine(func):
-        result = await func()
+        result = await func
     else:
         loop = asyncio.get_event_loop()
         if isinstance(func, partial):
-            result = await loop.run_in_executor(None, func.func, *func.args, **func.keywords)
+            result = await loop.run_in_executor(
+                None, func.func, *func.args, **func.keywords
+            )
         else:
             result = await loop.run_in_executor(None, func)
 
diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py
index a33f079..cf1df0d 100644
--- a/src/wordcab_transcribe/models.py
+++ b/src/wordcab_transcribe/models.py
@@ -22,7 +22,7 @@
 from enum import Enum
 from typing import List, Literal, NamedTuple, Optional, Union
 
-from faster_whisper.transcribe import Segment, Word
+from faster_whisper.transcribe import Segment
 from pydantic import BaseModel, field_validator
 from tensorshare import TensorShare
 
@@ -44,6 +44,15 @@ class Timestamps(str, Enum):
     hour_minute_second = "hms"
 
 
+class Word(BaseModel):
+    """Word model for the API."""
+
+    word: str
+    start: float
+    end: float
+    probability: float
+
+
 class Utterance(BaseModel):
     """Utterance model for the API."""
 
diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py
index 7565b37..6860e11 100644
--- a/src/wordcab_transcribe/services/asr_service.py
+++ b/src/wordcab_transcribe/services/asr_service.py
@@ -469,7 +469,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
         try:
             if isinstance(task.transcription.execution, LocalExecution):
                 out = await time_and_tell_async(
-                    lambda: self.services["transcription"](
+                    self.services["transcription"](
                         task.audio,
                         model_index=task.transcription.execution.index,
                         suppress_blank=False,
@@ -497,7 +497,7 @@ async def process_transcription(self, task: ASRTask, debug_mode: bool) -> None:
                     **task.transcription.options.model_dump(),
                 )
                 out = await time_and_tell_async(
-                    lambda: self.remote_transcription(
+                    self.remote_transcription(
                         url=task.transcription.execution.url,
                         data=data,
                     ),
@@ -536,7 +536,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
         try:
             if isinstance(task.diarization.execution, LocalExecution):
                 out = await time_and_tell_async(
-                    lambda: self.services["diarization"](
+                    self.services["diarization"](
                         waveform=task.audio,
                         audio_duration=task.duration,
                         oracle_num_speakers=task.diarization.num_speakers,
@@ -557,7 +557,7 @@ async def process_diarization(self, task: ASRTask, debug_mode: bool) -> None:
                     num_speakers=task.diarization.num_speakers,
                 )
                 out = await time_and_tell_async(
-                    lambda: self.remote_diarization(
+                    self.remote_diarization(
                         url=task.diarization.execution.url,
                         data=data,
                     ),
@@ -601,7 +601,7 @@ def process_post_processing(self, task: ASRTask) -> None:
 
         if task.multi_channel:
             utterances, process_time = time_and_tell(
-                lambda: self.services[
+                self.services[
                     "post_processing"
                 ].multi_channel_speaker_mapping(task.transcription.result),
                 func_name="multi_channel_speaker_mapping",
@@ -611,7 +611,7 @@
 
         else:
             formatted_segments, process_time = time_and_tell(
-                lambda: format_segments(
+                format_segments(
                     transcription_output=task.transcription.result,
                 ),
                 func_name="format_segments",
@@ -621,7 +621,7 @@
 
             if task.diarization.execution is not None:
                 utterances, process_time = time_and_tell(
-                    lambda: self.services[
+                    self.services[
                         "post_processing"
                     ].single_channel_speaker_mapping(
                         transcript_segments=formatted_segments,
@@ -636,7 +636,7 @@ def process_post_processing(self, task: ASRTask) -> None:
                 utterances = formatted_segments
 
         final_utterances, process_time = time_and_tell(
-            lambda: self.services[
+            self.services[
                 "post_processing"
             ].final_processing_before_returning(
                 utterances=utterances,
@@ -672,9 +672,10 @@ async def remote_transcription(
         async with session.post(
             url=f"{url}/api/v1/transcribe",
             data=data.model_dump_json(),
+            headers={"Content-Type": "application/json"},
         ) as response:
             if response.status != 200:
-                raise Exception(response.detail)
+                raise Exception(response.status)
             else:
                 return TranscriptionOutput(**await response.json())
 
@@ -688,9 +688,11 @@ async def remote_diarization(
         async with session.post(
             url=f"{url}/api/v1/diarize",
             data=data.model_dump_json(),
+            headers={"Content-Type": "application/json"},
         ) as response:
             if response.status != 200:
-                raise Exception(response.detail)
+                r = await response.json()
+                raise Exception(r["detail"])
             else:
                 return DiarizationOutput(**await response.json())
 
diff --git a/src/wordcab_transcribe/utils.py b/src/wordcab_transcribe/utils.py
index 07fa56b..43a0c77 100644
--- a/src/wordcab_transcribe/utils.py
+++ b/src/wordcab_transcribe/utils.py
@@ -42,6 +42,7 @@
     Timestamps,
     TranscriptionOutput,
     Utterance,
+    Word,
 )
 
 
@@ -424,7 +425,15 @@ def format_segments(transcription_output: TranscriptionOutput) -> List[Utterance
             text=segment.text,
             start=segment.start,
             end=segment.end,
-            words=segment.words,
+            words=[
+                Word(
+                    word=word.word,
+                    start=word.start,
+                    end=word.end,
+                    probability=word.probability,
+                )
+                for word in segment.words
+            ]
         )
         for segment in transcription_output.segments
     ]