From 9c7c5130f200b94fa9a5367e2f1c5d10e17206a4 Mon Sep 17 00:00:00 2001
From: chainyo
Date: Thu, 5 Oct 2023 12:44:38 +0000
Subject: [PATCH 1/2] fix dual_channel bug

---
 src/wordcab_transcribe/services/transcribe_service.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py
index eba00f4..c8f5175 100644
--- a/src/wordcab_transcribe/services/transcribe_service.py
+++ b/src/wordcab_transcribe/services/transcribe_service.py
@@ -388,7 +388,7 @@ def multi_channel(
                         "start": word.start,
                         "end": word.end,
                         "word": word.word,
-                        "score": word.probability,
+                        "probability": word.probability,
                     }
                 )
 

From fbc8e4cb791782523b1ac4f3f5698d7a2bd4090f Mon Sep 17 00:00:00 2001
From: chainyo
Date: Thu, 5 Oct 2023 13:56:01 +0000
Subject: [PATCH 2/2] fix output multi-channel

---
 src/wordcab_transcribe/models.py                 | 16 +++++++
 .../services/post_processing_service.py          |  7 +--
 .../services/transcribe_service.py               | 43 ++++++++-----------
 3 files changed, 37 insertions(+), 29 deletions(-)

diff --git a/src/wordcab_transcribe/models.py b/src/wordcab_transcribe/models.py
index 86a3e93..8203a04 100644
--- a/src/wordcab_transcribe/models.py
+++ b/src/wordcab_transcribe/models.py
@@ -53,6 +53,16 @@ class Word(BaseModel):
     probability: float
 
 
+class MultiChannelSegment(NamedTuple):
+    """Multi-channel segment model for the API."""
+
+    start: float
+    end: float
+    text: str
+    words: List[Word]
+    speaker: int
+
+
 class Utterance(BaseModel):
     """Utterance model for the API."""
 
@@ -513,6 +523,12 @@ class DiarizationRequest(BaseModel):
     num_speakers: int
 
 
+class MultiChannelTranscriptionOutput(BaseModel):
+    """Multi-channel transcription output model for the API."""
+
+    segments: List[MultiChannelSegment]
+
+
 class TranscriptionOutput(BaseModel):
     """Transcription output model for the API."""
 
diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py
index 23011be..420e8c2 100644
--- a/src/wordcab_transcribe/services/post_processing_service.py
+++ b/src/wordcab_transcribe/services/post_processing_service.py
@@ -24,6 +24,7 @@
 from wordcab_transcribe.models import (
     DiarizationOutput,
     DiarizationSegment,
+    MultiChannelTranscriptionOutput,
     Timestamps,
     TranscriptionOutput,
     Utterance,
@@ -75,13 +76,13 @@ def single_channel_speaker_mapping(
         return utterances
 
     def multi_channel_speaker_mapping(
-        self, multi_channel_segments: List[TranscriptionOutput]
+        self, multi_channel_segments: List[MultiChannelTranscriptionOutput]
     ) -> TranscriptionOutput:
         """
         Run the multi-channel post-processing functions on the inputs by merging the segments based on the timestamps.
 
         Args:
-            multi_channel_segments (List[TranscriptionOutput]):
+            multi_channel_segments (List[MultiChannelTranscriptionOutput]):
                 List of segments from multi speakers.
 
         Returns:
@@ -93,7 +94,7 @@ def multi_channel_speaker_mapping(
             for segment in output.segments
             for word in segment.words
         ]
-        words_with_speaker_mapping.sort(key=lambda _, word: word.start)
+        words_with_speaker_mapping.sort(key=lambda x: x[1].start)
 
         utterances: List[Utterance] = self.reconstruct_multi_channel_utterances(
             words_with_speaker_mapping
diff --git a/src/wordcab_transcribe/services/transcribe_service.py b/src/wordcab_transcribe/services/transcribe_service.py
index c8f5175..4a0fa9b 100644
--- a/src/wordcab_transcribe/services/transcribe_service.py
+++ b/src/wordcab_transcribe/services/transcribe_service.py
@@ -26,7 +26,12 @@
 from loguru import logger
 from tensorshare import Backend, TensorShare
 
-from wordcab_transcribe.models import TranscriptionOutput
+from wordcab_transcribe.models import (
+    MultiChannelSegment,
+    MultiChannelTranscriptionOutput,
+    TranscriptionOutput,
+    Word,
+)
 
 
 class FasterWhisperModel(NamedTuple):
@@ -312,7 +317,7 @@ def multi_channel(
         no_speech_threshold: float = 0.6,
         condition_on_previous_text: bool = False,
         prompt: Optional[str] = None,
-    ) -> TranscriptionOutput:
+    ) -> MultiChannelTranscriptionOutput:
         """
         Transcribe an audio file using the faster-whisper original pipeline.
 
@@ -342,7 +347,7 @@ def multi_channel(
             prompt (Optional[str]): Initial prompt to use for the generation.
 
         Returns:
-            TranscriptionOutput: Transcription output.
+            MultiChannelTranscriptionOutput: Multi-channel transcription output.
         """
         if isinstance(audio, torch.Tensor):
             _audio = audio.numpy()
@@ -374,27 +379,13 @@ def multi_channel(
             )
 
             for segment in segments:
-                segment_dict: dict = {
-                    "start": None,
-                    "end": None,
-                    "text": segment.text,
-                    "words": [],
-                    "speaker": speaker_id,
-                }
-
-                for word in segment.words:
-                    segment_dict["words"].append(
-                        {
-                            "start": word.start,
-                            "end": word.end,
-                            "word": word.word,
-                            "probability": word.probability,
-                        }
-                    )
-
-                segment_dict["start"] = segment_dict["words"][0]["start"]
-                segment_dict["end"] = segment_dict["words"][-1]["end"]
-
-                final_segments.append(segment_dict)
+                _segment = MultiChannelSegment(
+                    start=segment.start,
+                    end=segment.end,
+                    text=segment.text,
+                    words=[Word(**word._asdict()) for word in segment.words],
+                    speaker=speaker_id,
+                )
+                final_segments.append(_segment)
 
-        return TranscriptionOutput(segments=final_segments)
+        return MultiChannelTranscriptionOutput(segments=final_segments)
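
Editor's note on the sort fix in PATCH 2/2 (illustration, not part of the patch): list.sort passes each element to the key function as a single argument, so the old key=lambda _, word: word.start raised a TypeError the first time it ran; this was a runtime bug, not a style change. A minimal sketch of the failure and the fix, assuming the words_with_speaker_mapping elements are (speaker, word) tuples as the surrounding comprehension indicates. SimpleWord is a hypothetical stand-in for the project's Word model:

from typing import List, NamedTuple, Tuple


class SimpleWord(NamedTuple):
    """Hypothetical stand-in for the project's Word model."""

    word: str
    start: float
    end: float
    probability: float


# (speaker, word) tuples, mirroring words_with_speaker_mapping in the patch.
pairs: List[Tuple[int, SimpleWord]] = [
    (1, SimpleWord("world", 0.8, 1.2, 0.97)),
    (0, SimpleWord("hello", 0.0, 0.5, 0.99)),
]

# Old code: pairs.sort(key=lambda _, word: word.start)
# -> TypeError: <lambda>() missing 1 required positional argument: 'word',
#    because sort() calls key(element) with exactly one argument.

# Fixed code: take the single tuple argument and index into it.
pairs.sort(key=lambda x: x[1].start)

assert [w.word for _, w in pairs] == ["hello", "world"]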
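
The rewritten segment loop in transcribe_service.py leans on the fact that faster-whisper's word results behave like NamedTuples (the patch itself calls word._asdict()), so the resulting dict maps one-to-one onto the Word model's fields. That removes the hand-built dicts whose string keys caused the "score" vs. "probability" mismatch fixed in PATCH 1/2. A rough sketch of the round trip, with hypothetical values in place of real transcription output; FWWord mimics the library's word tuple for this sketch only:

from typing import List, NamedTuple

from pydantic import BaseModel


class FWWord(NamedTuple):
    """Mimics faster_whisper's word tuple (hypothetical stand-in)."""

    start: float
    end: float
    word: str
    probability: float


class Word(BaseModel):
    """Mirrors wordcab_transcribe.models.Word."""

    word: str
    start: float
    end: float
    probability: float


class MultiChannelSegment(NamedTuple):
    """Mirrors the NamedTuple added in PATCH 2/2."""

    start: float
    end: float
    text: str
    words: List[Word]
    speaker: int


fw_words = [FWWord(0.0, 0.5, "hello", 0.99), FWWord(0.6, 1.0, "world", 0.97)]

# _asdict() yields exactly the keyword arguments Word declares, so the
# conversion cannot reintroduce a mismatched key like "score".
segment = MultiChannelSegment(
    start=fw_words[0].start,
    end=fw_words[-1].end,
    text="hello world",
    words=[Word(**w._asdict()) for w in fw_words],
    speaker=0,
)
assert segment.words[0].probability == 0.99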