From 4422e79b9f00f1ab000197288250a4ce008e8981 Mon Sep 17 00:00:00 2001
From: Aleks
Date: Sat, 30 Mar 2024 20:19:41 -0400
Subject: [PATCH] Enabling longform diarization, skeleton for punct-based alignment

---
 .env                                               |  7 ++++
 src/wordcab_transcribe/config.py                   |  6 +++
 .../services/asr_service.py                        | 11 ++++++
 .../longform_diarization/configs/general.json      |  2 +-
 .../longform_diarization/configs/meeting.json      |  2 +-
 .../configs/telephonic.json                        |  2 +-
 .../longform_diarization/diarize_service.py        |  9 +++++
 .../services/post_processing_service.py            | 38 +++++++++++++++++++
 8 files changed, 74 insertions(+), 3 deletions(-)

diff --git a/.env b/.env
index 8d78945..cfb7e37 100644
--- a/.env
+++ b/.env
@@ -67,6 +67,13 @@ SHIFT_LENGTHS="1.0,0.75,0.625,0.5,0.25"
 # "1.0, 1.0, 1.0, 1.0, 1.0".
 MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0"
 #
+# --------------------------------------------------- POST-PROCESSING------------------------------------------------- #
+#
+# This parameter is used to control the punctuation-based alignment. If set to True, the predicted punctuation
+# will be used to adjust speaker diarization. The default value is True, but note this comes with a performance
+# tradeoff.
+ENABLE_PUNCTUATION_BASED_ALIGNMENT=False
+#
 # ---------------------------------------------- ASR TYPE CONFIGURATION ---------------------------------------------- #
 #
 # The asr_type parameter is used to control the type of ASR used. The available options are: `async` or `live`.
diff --git a/src/wordcab_transcribe/config.py b/src/wordcab_transcribe/config.py
index 3cf9c20..c78bc6f 100644
--- a/src/wordcab_transcribe/config.py
+++ b/src/wordcab_transcribe/config.py
@@ -54,6 +54,8 @@ class Settings:
     window_lengths: List[float]
     shift_lengths: List[float]
     multiscale_weights: List[float]
+    # Post-processing
+    enable_punctuation_based_alignment: bool
     # ASR type configuration
     asr_type: Literal["async", "live", "only_transcription", "only_diarization"]
     # Endpoint configuration
@@ -325,6 +327,10 @@ def __post_init__(self):
             window_lengths=window_lengths,
             shift_lengths=shift_lengths,
             multiscale_weights=multiscale_weights,
+            # Post-processing
+            enable_punctuation_based_alignment=getenv(
+                "ENABLE_PUNCTUATION_BASED_ALIGNMENT", True
+            ),
             # ASR type
             asr_type=getenv("ASR_TYPE", "async"),
             # Endpoints configuration
diff --git a/src/wordcab_transcribe/services/asr_service.py b/src/wordcab_transcribe/services/asr_service.py
index c66f08e..3312013 100644
--- a/src/wordcab_transcribe/services/asr_service.py
+++ b/src/wordcab_transcribe/services/asr_service.py
@@ -764,6 +764,17 @@ def process_post_processing(self, task: ASRTask) -> None:
         else:
             utterances = formatted_segments
 
+        if settings.enable_punctuation_based_alignment:
+            utterances, process_time = time_and_tell(
+                self.local_services.post_processing.punctuation_based_alignment(
+                    utterances=utterances,
+                    speaker_timestamps=task.diarization.result,
+                ),
+                func_name="punctuation_based_alignment",
+                debug_mode=self.debug_mode,
+            )
+            total_post_process_time += process_time
+
         final_utterances, process_time = time_and_tell(
             self.local_services.post_processing.final_processing_before_returning(
                 utterances=utterances,
diff --git a/src/wordcab_transcribe/services/longform_diarization/configs/general.json b/src/wordcab_transcribe/services/longform_diarization/configs/general.json
index 35a1194..0f30f89 100644
--- a/src/wordcab_transcribe/services/longform_diarization/configs/general.json
+++ b/src/wordcab_transcribe/services/longform_diarization/configs/general.json
@@ -1 +1 @@
-{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.08, "smoothing": false, "overlap": 0.5, "onset": 0.5, "offset": 0.3, "pad_onset": 0.2, "pad_offset": 0.2, "min_duration_on": 0.5, "min_duration_off": 0.5, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.9, 1.2, 0.5], "shift_length_in_sec": [0.95, 0.6, 0.25], "multiscale_weights": [1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 10, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": null, "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
+{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.08, "smoothing": false, "overlap": 0.5, "onset": 0.5, "offset": 0.3, "pad_onset": 0.2, "pad_offset": 0.2, "min_duration_on": 0.5, "min_duration_off": 0.5, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.9, 1.2, 0.5], "shift_length_in_sec": [0.95, 0.6, 0.25], "multiscale_weights": [1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 10, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": null, "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
diff --git a/src/wordcab_transcribe/services/longform_diarization/configs/meeting.json b/src/wordcab_transcribe/services/longform_diarization/configs/meeting.json
index b845f80..440f9ef 100644
--- a/src/wordcab_transcribe/services/longform_diarization/configs/meeting.json
+++ b/src/wordcab_transcribe/services/longform_diarization/configs/meeting.json
@@ -1 +1 @@
-{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.01, "smoothing": false, "overlap": 0.5, "onset": 0.9, "offset": 0.5, "pad_onset": 0, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.6, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5], "shift_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5, 0.25], "multiscale_weights": [1, 1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
\ No newline at end of file
+{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.01, "smoothing": false, "overlap": 0.5, "onset": 0.9, "offset": 0.5, "pad_onset": 0, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.6, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5], "shift_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5, 0.25], "multiscale_weights": [1, 1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
\ No newline at end of file
diff --git a/src/wordcab_transcribe/services/longform_diarization/configs/telephonic.json b/src/wordcab_transcribe/services/longform_diarization/configs/telephonic.json
index e2c50fc..c82c25e 100644
--- a/src/wordcab_transcribe/services/longform_diarization/configs/telephonic.json
+++ b/src/wordcab_transcribe/services/longform_diarization/configs/telephonic.json
@@ -1 +1 @@
-{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.15, "shift_length_in_sec": 0.01, "smoothing": "median", "overlap": 0.5, "onset": 0.1, "offset": 0.1, "pad_onset": 0.1, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.2, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], "multiscale_weights": [1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": "diar_msdd_telephonic", "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
\ No newline at end of file
+{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.15, "shift_length_in_sec": 0.01, "smoothing": "median", "overlap": 0.5, "onset": 0.1, "offset": 0.1, "pad_onset": 0.1, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.2, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], "multiscale_weights": [1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": "diar_msdd_telephonic", "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
\ No newline at end of file
diff --git a/src/wordcab_transcribe/services/longform_diarization/diarize_service.py b/src/wordcab_transcribe/services/longform_diarization/diarize_service.py
index b80049d..1a28bd6 100644
--- a/src/wordcab_transcribe/services/longform_diarization/diarize_service.py
+++ b/src/wordcab_transcribe/services/longform_diarization/diarize_service.py
@@ -19,6 +19,7 @@
 # and limitations under the License.
 
 """Longform diarization Service for audio files."""
+import json
 import tempfile
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
@@ -27,6 +28,7 @@
 import torch
 import torchaudio
 from nemo.collections.asr.models import NeuralDiarizer
+from omegaconf import OmegaConf
 from tensorshare import Backend, TensorShare
 
 from wordcab_transcribe.models import DiarizationOutput
@@ -43,6 +45,7 @@ class LongFormDiarizeService:
     def __init__(
         self,
         device: str,
+        domain: str = "telephonic",
     ) -> None:
         """
         Initialize the DiarizeService.
@@ -53,9 +56,15 @@ def __init__(
         Returns:
             None
         """
+        config_path = Path(__file__).parent / "configs" / f"{domain}.json"
+        with open(config_path, "r") as f:
+            general_conf = json.load(f)
+
+        config = OmegaConf.create(general_conf)
         self.diarization_model = NeuralDiarizer.from_pretrained(
             model_name="diar_msdd_telephonic"
         ).to(device)
+        self.diarization_model._cfg = config  # TODO: Ability to modify config
 
 
     def __call__(
diff --git a/src/wordcab_transcribe/services/post_processing_service.py b/src/wordcab_transcribe/services/post_processing_service.py
index 9ff43fa..3bc82d8 100644
--- a/src/wordcab_transcribe/services/post_processing_service.py
+++ b/src/wordcab_transcribe/services/post_processing_service.py
@@ -21,6 +21,7 @@
 
 from typing import List, Tuple, Union
 
+from wordcab_transcribe.config import settings
 from wordcab_transcribe.models import (
     DiarizationOutput,
     DiarizationSegment,
@@ -40,6 +41,13 @@ def __init__(self) -> None:
         """Initialize the PostProcessingService."""
         self.sample_rate = 16000
 
+        if settings.enable_punctuation_based_alignment:
+            from deepmultilingualpunctuation import PunctuationModel
+
+            self.punct_model = PunctuationModel(model="kredor/punctuate-all")
+        else:
+            self.punct_model = None
+
     def single_channel_speaker_mapping(
         self,
         transcript_segments: List[Utterance],
@@ -379,6 +387,36 @@
 
         return [Utterance(**sentence) for sentence in sentences]
 
+    def punctuation_based_alignment(
+        self,
+        utterances: List[Utterance],
+        speaker_timestamps: DiarizationOutput,
+    ):
+        pass
+        # word_list = []
+        # for utterance in utterances:
+        #     for word in utterance.words:
+        #         word_list.append(word.word)
+        #
+        # labled_words = self.punct_model.predict(word_list)
+        #
+        # ending_puncts = ".?!"
+        # model_puncts = ".,;:!?"
+        #
+        # def is_acronym(w):
+        #     return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", w)
+        #
+        # for ix, (word, labeled_tuple) in enumerate(zip(word_list, labled_words)):
+        #     if (
+        #         word
+        #         and labeled_tuple[1] in ending_puncts
+        #         and (word[-1] not in model_puncts or is_acronym(word))
+        #     ):
+        #         word += labeled_tuple[1]
+        #     if word.endswith(".."):
+        #         word = word.rstrip(".")
+        #     word_dict["word"] = word
+
     def final_processing_before_returning(
         self,
         utterances: List[Utterance],
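
Note on the new ENABLE_PUNCTUATION_BASED_ALIGNMENT flag: os.getenv returns the raw string from the environment, so a value such as "False" stays truthy unless the Settings model coerces it to a boolean. A minimal sketch of the kind of normalization that could sit alongside the getenv call; the _env_flag helper below is illustrative and not part of this patch:

    from os import getenv


    def _env_flag(name: str, default: bool) -> bool:
        """Interpret common truthy/falsy spellings coming from an .env file."""
        raw = getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")


    # Unset -> True (the documented default); ENABLE_PUNCTUATION_BASED_ALIGNMENT=False in .env -> False.
    enable_punctuation_based_alignment = _env_flag("ENABLE_PUNCTUATION_BASED_ALIGNMENT", True)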
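
Note on the new domain argument of LongFormDiarizeService: it selects one of the bundled JSON configs (general, meeting, or telephonic), converts it to an OmegaConf object, and overwrites the pretrained model's private _cfg attribute. A minimal instantiation sketch under those assumptions; the device string and domain choice are illustrative, and the __call__ arguments are omitted because that signature is outside this hunk:

    import torch

    from wordcab_transcribe.services.longform_diarization.diarize_service import (
        LongFormDiarizeService,
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # "telephonic" is the default; "general" and "meeting" load the other bundled configs.
    meeting_diarizer = LongFormDiarizeService(device=device, domain="meeting")

Because the override goes through the private _cfg attribute rather than a NeMo update API, it is worth confirming that NeuralDiarizer actually picks the new values up after initialization, which is presumably what the TODO about modifying the config refers to.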
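
Note on punctuation_based_alignment: the method is deliberately left as a skeleton (a pass plus commented-out pseudocode). A minimal sketch of the first stage those comments describe, restoring predicted sentence-ending punctuation to the raw word list with the kredor/punctuate-all model; the speaker re-alignment against speaker_timestamps is not implemented here, and the standalone function below is an illustration rather than the final method:

    import re
    from typing import List

    from deepmultilingualpunctuation import PunctuationModel

    ENDING_PUNCTS = ".?!"
    MODEL_PUNCTS = ".,;:!?"


    def is_acronym(word: str) -> bool:
        """Return True for dotted acronyms such as 'U.S.'."""
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", word) is not None


    def restore_ending_punctuation(words: List[str], punct_model: PunctuationModel) -> List[str]:
        """Append predicted sentence-ending punctuation to words that lack it."""
        labeled_words = punct_model.predict(words)  # one (word, punctuation, score) entry per word
        restored = []
        for word, labeled_tuple in zip(words, labeled_words):
            if (
                word
                and labeled_tuple[1] in ENDING_PUNCTS
                and (word[-1] not in MODEL_PUNCTS or is_acronym(word))
            ):
                word += labeled_tuple[1]
            if word.endswith(".."):
                word = word.rstrip(".")
            restored.append(word)
        return restored


    # Usage sketch: words = [w.word for u in utterances for w in u.words]
    # restored = restore_ending_punctuation(words, PunctuationModel(model="kredor/punctuate-all"))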