From 1e1b40f21462fd2f2ea3f024f73d312c99f022bd Mon Sep 17 00:00:00 2001 From: Aleks Date: Sun, 8 Oct 2023 15:35:58 -0400 Subject: [PATCH] Adjust VAD speech padding for diarizer, add additional logic to speaker mapping function --- .../services/post_processing_service.py | 79 ++++++++++++++++--- .../services/vad_service.py | 2 +- 2 files changed, 68 insertions(+), 13 deletions(-) diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py index 420e8c2..dd3cc81 100644 --- a/src/wordcab_transcribe/services/post_processing_service.py +++ b/src/wordcab_transcribe/services/post_processing_service.py @@ -119,7 +119,26 @@ def segments_speaker_mapping( Returns: List[dict]: List of sentences with speaker mapping. """ + + def _assign_speaker( + mapping: list, + seg_index: int, + split: bool, + current_speaker: str, + current_split_len: int, + ): + """Assign speaker to the segment.""" + if split and len(mapping) > 1: + last_split_len = len(mapping[seg_index - 1].text) + if last_split_len > current_split_len: + current_speaker = mapping[seg_index - 1].speaker + elif last_split_len < current_split_len: + mapping[seg_index - 1].speaker = current_speaker + return current_speaker + + threshold = 0.3 turn_idx = 0 + was_split = False _, end, speaker = speaker_timestamps[turn_idx] segment_index = 0 @@ -131,8 +150,10 @@ def segments_speaker_mapping( segment.end, segment.text, ) - - while segment_start > float(end) or abs(segment_start - float(end)) < 0.3: + while ( + segment_start > float(end) + or abs(segment_start - float(end)) < threshold + ): turn_idx += 1 turn_idx = min(turn_idx, len(speaker_timestamps) - 1) _, end, speaker = speaker_timestamps[turn_idx] @@ -140,50 +161,76 @@ def segments_speaker_mapping( end = segment_end break - if segment_end > float(end): + if segment_end > float(end) and abs(segment_end - float(end)) > threshold: words = segment.words - word_index = next( ( i for i, word in enumerate(words) - if word.start > float(end) or abs(word.start - float(end)) < 0.3 + if word.start > float(end) + or abs(word.start - float(end)) < threshold ), None, ) if word_index is not None: - _splitted_segment = segment_text.split() + _split_segment = segment_text.split() if word_index > 0: + text = " ".join(_split_segment[:word_index]) + speaker = _assign_speaker( + segment_speaker_mapping, + segment_index, + was_split, + speaker, + len(text), + ) + _segment_to_add = Utterance( start=words[0].start, end=words[word_index - 1].end, - text=" ".join(_splitted_segment[:word_index]), + text=text, speaker=speaker, words=words[:word_index], ) - else: + text = _split_segment[0] + speaker = _assign_speaker( + segment_speaker_mapping, + segment_index, + was_split, + speaker, + len(text), + ) + _segment_to_add = Utterance( start=words[0].start, end=words[0].end, - text=_splitted_segment[0], + text=_split_segment[0], speaker=speaker, words=words[:1], ) - segment_speaker_mapping.append(_segment_to_add) transcript_segments.insert( segment_index + 1, Utterance( start=words[word_index].start, end=segment_end, - text=" ".join(_splitted_segment[word_index:]), + text=" ".join(_split_segment[word_index:]), words=words[word_index:], ), ) + was_split = True else: + speaker = _assign_speaker( + segment_speaker_mapping, + segment_index, + was_split, + speaker, + len(segment_text), + ) + was_split = False + segment_speaker_mapping.append( Utterance( start=segment_start, @@ -194,6 +241,15 @@ def segments_speaker_mapping( ) ) else: + speaker = _assign_speaker( + segment_speaker_mapping, + segment_index, + was_split, + speaker, + len(segment_text), + ) + was_split = False + segment_speaker_mapping.append( Utterance( start=segment_start, @@ -203,7 +259,6 @@ def segments_speaker_mapping( words=segment.words, ) ) - segment_index += 1 return segment_speaker_mapping diff --git a/src/wordcab_transcribe/services/vad_service.py b/src/wordcab_transcribe/services/vad_service.py index ef34412..396ea11 100644 --- a/src/wordcab_transcribe/services/vad_service.py +++ b/src/wordcab_transcribe/services/vad_service.py @@ -38,7 +38,7 @@ def __init__(self) -> None: max_speech_duration_s=30, min_silence_duration_ms=100, window_size_samples=512, - speech_pad_ms=30, + speech_pad_ms=400, ) def __call__(