diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py
index 86a3e93..8203a04 100644
--- a/src/wordcab_transcribe/models.py
+++ b/src/wordcab_transcribe/models.py
@@ -53,6 +53,16 @@ class Word(BaseModel):
     probability: float
 
 
+class MultiChannelSegment(NamedTuple):
+    """Multi-channel segment model for the API."""
+
+    start: float
+    end: float
+    text: str
+    words: List[Word]
+    speaker: int
+
+
 class Utterance(BaseModel):
     """Utterance model for the API."""
 
@@ -513,6 +523,12 @@ class DiarizationRequest(BaseModel):
     num_speakers: int
 
 
+class MultiChannelTranscriptionOutput(BaseModel):
+    """Multi-channel transcription output model for the API."""
+
+    segments: List[MultiChannelSegment]
+
+
 class TranscriptionOutput(BaseModel):
     """Transcription output model for the API."""
 
diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py
index 23011be..420e8c2 100644
--- a/src/wordcab_transcribe/services/post_processing_service.py
+++ b/src/wordcab_transcribe/services/post_processing_service.py
@@ -24,6 +24,7 @@
 from wordcab_transcribe.models import (
     DiarizationOutput,
     DiarizationSegment,
+    MultiChannelTranscriptionOutput,
     Timestamps,
     TranscriptionOutput,
     Utterance,
@@ -75,13 +76,13 @@ def single_channel_speaker_mapping(
         return utterances
 
     def multi_channel_speaker_mapping(
-        self, multi_channel_segments: List[TranscriptionOutput]
+        self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
     ) -> TranscriptionOutput:
         """
         Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.
 
         Args:
-            multi_channel_segments (List[TranscriptionOutput]):
+            multi_channel_segments (List[MultiChannelTranscriptionOutput]):
                 List of segments from multi speakers.
 
         Returns:
@@ -93,7 +94,7 @@ def multi_channel_speaker_mapping(
             for segment in output.segments
             for word in segment.words
         ]
-        words_with_speaker_mapping.sort(key=lambda _, word: word.start)
+        words_with_speaker_mapping.sort(key=lambda x: x[1].start)
 
         utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
             words_with_speaker_mapping
diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py
index eba00f4..4a0fa9b 100644
--- a/src/wordcab_transcribe/services/transcribe_service.py
+++ b/src/wordcab_transcribe/services/transcribe_service.py
@@ -26,7 +26,12 @@
 from loguru import logger
 from tensorshare import Backend, TensorShare
 
-from wordcab_transcribe.models import TranscriptionOutput
+from wordcab_transcribe.models import (
+    MultiChannelSegment,
+    MultiChannelTranscriptionOutput,
+    TranscriptionOutput,
+    Word,
+)
 
 
 class FasterWhisperModel(NamedTuple):
@@ -312,7 +317,7 @@ def multi_channel(
         no_speech_threshold: float = 0.6,
         condition_on_previous_text: bool = False,
         prompt: Optional[str] = None,
-    ) -> TranscriptionOutput:
+    ) -> MultiChannelTranscriptionOutput:
         """
         Transcribe an audio file using the faster-whisper original pipeline.
 
@@ -342,7 +347,7 @@ def multi_channel(
             prompt (Optional[str]): Initial prompt to use for the generation.
 
         Returns:
-            TranscriptionOutput: Transcription output.
+            MultiChannelTranscriptionOutput: Multi-channel transcription segments in a list.
         """
         if isinstance(audio, torch.Tensor):
             _audio = audio.numpy()
@@ -374,27 +379,13 @@
         )
 
         for segment in segments:
-            segment_dict: dict = {
-                "start": None,
-                "end": None,
-                "text": segment.text,
-                "words": [],
-                "speaker": speaker_id,
-            }
-
-            for word in segment.words:
-                segment_dict["words"].append(
-                    {
-                        "start": word.start,
-                        "end": word.end,
-                        "word": word.word,
-                        "score": word.probability,
-                    }
-                )
-
-            segment_dict["start"] = segment_dict["words"][0]["start"]
-            segment_dict["end"] = segment_dict["words"][-1]["end"]
-
-            final_segments.append(segment_dict)
+            _segment = MultiChannelSegment(
+                start=segment.start,
+                end=segment.end,
+                text=segment.text,
+                words=[Word(**word._asdict()) for word in segment.words],
+                speaker=speaker_id,
+            )
+            final_segments.append(_segment)
 
-        return TranscriptionOutput(segments=final_segments)
+        return MultiChannelTranscriptionOutput(segments=final_segments)