diff --git a/.env b/.env index 854151a..997a524 100644 --- a/.env +++ b/.env @@ -16,16 +16,6 @@ API_PREFIX="/api/v1" # Debug mode for FastAPI. It allows for hot reloading when code changes in development. DEBUG=True # -# ----------------------------------------------- BATCH CONFIGURATION ------------------------------------------------ # -# -# The batch_size parameter is used to control the number of audio files that are processed in parallel. -# If your server GPU has a lot of memory, you can increase this value to improve performance. -# For simplicity, we recommend leaving this value at 1, unless you are sure that your GPU has enough memory (> 40GB) -BATCH_SIZE=1 -# The max_wait parameter is used to control the maximum amount of time (in seconds) that the server will wait for -# processing the tasks in the queue, if not empty. It's useful only when the batch_size is greater than 1. -MAX_WAIT=0.1 -# # ----------------------------------------------- MODELS CONFIGURATION ----------------------------------------------- # # # ----------------------------------------------------- WHISPER ------------------------------------------------------ # @@ -47,20 +37,20 @@ WHISPER_MODEL="large-v2" COMPUTE_TYPE="float16" # The extra_languages parameter is used to control the languages that need an extra model to be loaded. # You can specify multiple languages separated by a comma. The available languages are: `he` (Hebrew). -EXTRA_LANGUAGES="" -# -# --------------------------------------------------- NVIDIA NEMO ---------------------------------------------------- # -# -# The nemo_domain_type define the configuration file used by the model for diarization. The available options are: -# `general`, `meeting` and `telephonic`. The default value is `telephonic`. If you choose another type, you will need -# to provide a custom model -NEMO_DOMAIN_TYPE="telephonic" -# The nemo_storage_path parameter is used to control the path where the NeuralDiarizer from the NeMo toolkit will -# store the diarization models. -NEMO_STORAGE_PATH="nemo_storage" -# The nemo_output_path parameter is used to control the path where the NeuralDiarizer from the NeMo toolkit will -# store the diarization outputs. -NEMO_OUTPUT_PATH="nemo_outputs" +EXTRA_LANGUAGES= +# +# --------------------------------------------------- DIARIZATION ---------------------------------------------------- # +# +# An MSDD (Multiscale Diarization Decoder) model is trained on multiple window lengths. +# The window_lengths are specified in seconds and separated by a comma. If not specified, the default value will +# be "1.5, 1.25, 1.0, 0.75, 0.5". +WINDOW_LENGTHS="1.5,1.25,1.0,0.75,0.5" +# The shift_lengths are specified in seconds and separated by a comma. If not specified, the default value will +# be "0.75, 0.625, 0.5, 0.375, 0.25". +SHIFT_LENGTHS="0.75,0.625,0.5,0.375,0.25" +# The multiscale_weights are float values separated by a comma. If not specified, the default value will be +# "1.0, 1.0, 1.0, 1.0, 1.0".
+MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0" # # ---------------------------------------------- ASR TYPE CONFIGURATION ---------------------------------------------- # # diff --git a/.gitignore b/.gitignore index 532f4c3..912bfe9 100644 --- a/.gitignore +++ b/.gitignore @@ -10,9 +10,10 @@ __pycache__/ /data/ /dist/ /docs/_build/ +/infer_out_dir/ /src/*.egg-info/ +async_bench.py test.ipynb test.py -async_bench.py whisper_model whisper_model_he diff --git a/Dockerfile b/Dockerfile index 5699e0b..e769468 100644 --- a/Dockerfile +++ b/Dockerfile @@ -31,7 +31,6 @@ COPY ./poetry.lock ./pyproject.toml ./ RUN poetry install --only main COPY ./wordcab_transcribe /app/wordcab_transcribe -COPY ./config /app/config COPY ./.env /app/.env WORKDIR /app diff --git a/config/nemo/diar_infer_general.yaml b/config/nemo/diar_infer_general.yaml deleted file mode 100644 index 941622b..0000000 --- a/config/nemo/diar_infer_general.yaml +++ /dev/null @@ -1,91 +0,0 @@ -# This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. -# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. -# The configurations in this YAML file is optimized to show balanced performances on various types of domain. VAD is optimized on multilingual ASR datasets and diarizer is optimized on DIHARD3 development set. -# An example line in an input manifest file (`.json` format): -# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} -name: &name "ClusterDiarizer" - -num_workers: 1 -sample_rate: 16000 -batch_size: 64 -device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) -verbose: True # enable additional logging - -diarizer: - manifest_filepath: ??? - out_dir: ??? - oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps - collar: 0.25 # Collar value for scoring - ignore_overlap: True # Consider or ignore overlap segments while scoring - - vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name - external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. 
Only one of model_path or external_vad_manifest should be set - - parameters: # Tuned by detection error rate (false alarm + miss) on multilingual ASR evaluation datasets - window_length_in_sec: 0.63 # Window length in sec for VAD context input - shift_length_in_sec: 0.08 # Shift length in sec for generate frame level VAD prediction - smoothing: False # False or type of smoothing method (eg: median) - overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.5 # Onset threshold for detecting the beginning and end of a speech - offset: 0.3 # Offset threshold for detecting the end of a speech - pad_onset: 0.2 # Adding durations before each speech segment - pad_offset: 0.2 # Adding durations after each speech segment - min_duration_on: 0.5 # Threshold for small non_speech deletion - min_duration_off: 0.5 # Threshold for short speech segment deletion - filter_speech_first: True - - speaker_embeddings: - model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) - parameters: - window_length_in_sec: [1.9, 1.2, 0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] - shift_length_in_sec: [0.95, 0.6, 0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] - multiscale_weights: [1, 1, 1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. - - clustering: - parameters: - oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. - max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. - enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 10 # The higher the number, the more values will be examined with more time. - maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. - - msdd_model: - model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) - parameters: - use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. - infer_batch_size: 25 # Batch size for MSDD inference. - sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. - seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. - split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. - diar_window_length: 50 # The length of split short sequence when split_infer is True. - overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. - - asr: - model_path: null # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. 
- parameters: - asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. - asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. - asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. - decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. - word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. - fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. - colored_text: False # If True, use colored text to distinguish speakers in the output transcript. - print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. - break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - - ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) - pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. - beam_width: 32 - alpha: 0.5 - beta: 2.5 - - realigning_lm_parameters: # Experimental feature - arpa_language_model: null # Provide a KenLM language model in .arpa format. - min_number_of_words: 3 # Min number of words for the left context. - max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. diff --git a/config/nemo/diar_infer_meeting.yaml b/config/nemo/diar_infer_meeting.yaml deleted file mode 100644 index e03cd18..0000000 --- a/config/nemo/diar_infer_meeting.yaml +++ /dev/null @@ -1,91 +0,0 @@ -# This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. -# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. -# The configurations in this YAML file is suitable for 3~5 speakers participating in a meeting and may not show the best performance on other types of dialogues. -# An example line in an input manifest file (`.json` format): -# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} -name: &name "ClusterDiarizer" - -num_workers: 1 -sample_rate: 16000 -batch_size: 64 -device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) -verbose: True # enable additional logging - -diarizer: - manifest_filepath: ??? - out_dir: ??? 
- oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps - collar: 0.25 # Collar value for scoring - ignore_overlap: True # Consider or ignore overlap segments while scoring - - vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name - external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set - - parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) - window_length_in_sec: 0.63 # Window length in sec for VAD context input - shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction - smoothing: False # False or type of smoothing method (eg: median) - overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.9 # Onset threshold for detecting the beginning and end of a speech - offset: 0.5 # Offset threshold for detecting the end of a speech - pad_onset: 0 # Adding durations before each speech segment - pad_offset: 0 # Adding durations after each speech segment - min_duration_on: 0 # Threshold for small non_speech deletion - min_duration_off: 0.6 # Threshold for short speech segment deletion - filter_speech_first: True - - speaker_embeddings: - model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) - parameters: - window_length_in_sec: [3.0, 2.5, 2.0, 1.5, 1.0, 0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] - shift_length_in_sec: [1.5, 1.25, 1.0, 0.75, 0.5, 0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] - multiscale_weights: [1, 1, 1, 1, 1, 1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. - - clustering: - parameters: - oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. - max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. - enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. - maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. - - msdd_model: - model_path: null # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) - parameters: - use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. - infer_batch_size: 25 # Batch size for MSDD inference. - sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. - seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. 
- split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. - diar_window_length: 50 # The length of split short sequence when split_infer is True. - overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. - - asr: - model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. - parameters: - asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. - asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. - asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. - decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. - word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. - fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. - colored_text: False # If True, use colored text to distinguish speakers in the output transcript. - print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. - break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - - ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) - pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. - beam_width: 32 - alpha: 0.5 - beta: 2.5 - - realigning_lm_parameters: # Experimental feature - arpa_language_model: null # Provide a KenLM language model in .arpa format. - min_number_of_words: 3 # Min number of words for the left context. - max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. diff --git a/config/nemo/diar_infer_telephonic.yaml b/config/nemo/diar_infer_telephonic.yaml deleted file mode 100644 index 1c2358c..0000000 --- a/config/nemo/diar_infer_telephonic.yaml +++ /dev/null @@ -1,91 +0,0 @@ -# This YAML file is created for all types of offline speaker diarization inference tasks in `/example/speaker_tasks/diarization` folder. -# The inference parameters for VAD, speaker embedding extractor, clustering module, MSDD module, ASR decoder are all included in this YAML file. -# All the keys under `diarizer` key (`vad`, `speaker_embeddings`, `clustering`, `msdd_model`, `asr`) can be selectively used for its own purpose and also can be ignored if the module is not used. -# The configurations in this YAML file is suitable for telephone recordings involving 2~8 speakers in a session and may not show the best performance on the other types of acoustic conditions or dialogues. 
-# An example line in an input manifest file (`.json` format): -# {"audio_filepath": "/path/to/audio_file", "offset": 0, "duration": null, "label": "infer", "text": "-", "num_speakers": null, "rttm_filepath": "/path/to/rttm/file", "uem_filepath": "/path/to/uem/file"} -name: &name "ClusterDiarizer" - -num_workers: 1 -sample_rate: 16000 -batch_size: 64 -device: null # can specify a specific device, i.e: cuda:1 (default cuda if cuda available, else cpu) -verbose: False # enable additional logging - -diarizer: - manifest_filepath: ??? - out_dir: ??? - oracle_vad: False # If True, uses RTTM files provided in the manifest file to get speech activity (VAD) timestamps - collar: 0.25 # Collar value for scoring - ignore_overlap: True # Consider or ignore overlap segments while scoring - - vad: - model_path: vad_multilingual_marblenet # .nemo local model path or pretrained VAD model name - external_vad_manifest: null # This option is provided to use external vad and provide its speech activity labels for speaker embeddings extraction. Only one of model_path or external_vad_manifest should be set - - parameters: # Tuned parameters for CH109 (using the 11 multi-speaker sessions as dev set) - window_length_in_sec: 0.15 # Window length in sec for VAD context input - shift_length_in_sec: 0.01 # Shift length in sec for generate frame level VAD prediction - smoothing: "median" # False or type of smoothing method (eg: median) - overlap: 0.5 # Overlap ratio for overlapped mean/median smoothing filter - onset: 0.1 # Onset threshold for detecting the beginning and end of a speech - offset: 0.1 # Offset threshold for detecting the end of a speech - pad_onset: 0.1 # Adding durations before each speech segment - pad_offset: 0 # Adding durations after each speech segment - min_duration_on: 0 # Threshold for small non_speech deletion - min_duration_off: 0.2 # Threshold for short speech segment deletion - filter_speech_first: True - - speaker_embeddings: - model_path: titanet_large # .nemo local model path or pretrained model name (titanet_large, ecapa_tdnn or speakerverification_speakernet) - parameters: - window_length_in_sec: [1.5, 1.25, 1.0, 0.75, 0.5] # Window length(s) in sec (floating-point number). either a number or a list. ex) 1.5 or [1.5,1.0,0.5] - shift_length_in_sec: [0.75, 0.625, 0.5, 0.375, 0.25] # Shift length(s) in sec (floating-point number). either a number or a list. ex) 0.75 or [0.75,0.5,0.25] - multiscale_weights: [1, 1, 1, 1, 1] # Weight for each scale. should be null (for single scale) or a list matched with window/shift scale count. ex) [0.33,0.33,0.33] - save_embeddings: True # If True, save speaker embeddings in pickle format. This should be True if clustering result is used for other models, such as `msdd_model`. - - clustering: - parameters: - oracle_num_speakers: False # If True, use num of speakers value provided in manifest file. - max_num_speakers: 8 # Max number of speakers for each recording. If an oracle number of speakers is passed, this value is ignored. - enhanced_count_thres: 80 # If the number of segments is lower than this number, enhanced speaker counting is activated. - max_rp_threshold: 0.25 # Determines the range of p-value search: 0 < p <= max_rp_threshold. - sparse_search_volume: 30 # The higher the number, the more values will be examined with more time. - maj_vote_spk_count: False # If True, take a majority vote on multiple p-values to estimate the number of speakers. 
- - msdd_model: - model_path: diar_msdd_telephonic # .nemo local model path or pretrained model name for multiscale diarization decoder (MSDD) - parameters: - use_speaker_model_from_ckpt: True # If True, use speaker embedding model in checkpoint. If False, the provided speaker embedding model in config will be used. - infer_batch_size: 25 # Batch size for MSDD inference. - sigmoid_threshold: [0.7] # Sigmoid threshold for generating binarized speaker labels. The smaller the more generous on detecting overlaps. - seq_eval_mode: False # If True, use oracle number of speaker and evaluate F1 score for the given speaker sequences. Default is False. - split_infer: True # If True, break the input audio clip to short sequences and calculate cluster average embeddings for inference. - diar_window_length: 50 # The length of split short sequence when split_infer is True. - overlap_infer_spk_limit: 5 # If the estimated number of speakers are larger than this number, overlap speech is not estimated. - - asr: - model_path: stt_en_conformer_ctc_large # Provide NGC cloud ASR model name. stt_en_conformer_ctc_* models are recommended for diarization purposes. - parameters: - asr_based_vad: False # if True, speech segmentation for diarization is based on word-timestamps from ASR inference. - asr_based_vad_threshold: 1.0 # Threshold (in sec) that caps the gap between two words when generating VAD timestamps using ASR based VAD. - asr_batch_size: null # Batch size can be dependent on each ASR model. Default batch sizes are applied if set to null. - decoder_delay_in_sec: null # Native decoder delay. null is recommended to use the default values for each ASR model. - word_ts_anchor_offset: null # Offset to set a reference point from the start of the word. Recommended range of values is [-0.05 0.2]. - word_ts_anchor_pos: "start" # Select which part of the word timestamp we want to use. The options are: 'start', 'end', 'mid'. - fix_word_ts_with_VAD: False # Fix the word timestamp using VAD output. You must provide a VAD model to use this feature. - colored_text: False # If True, use colored text to distinguish speakers in the output transcript. - print_time: True # If True, the start and end time of each speaker turn is printed in the output transcript. - break_lines: False # If True, the output transcript breaks the line to fix the line width (default is 90 chars) - - ctc_decoder_parameters: # Optional beam search decoder (pyctcdecode) - pretrained_language_model: null # KenLM model file: .arpa model file or .bin binary file. - beam_width: 32 - alpha: 0.5 - beta: 2.5 - - realigning_lm_parameters: # Experimental feature - arpa_language_model: null # Provide a KenLM language model in .arpa format. - min_number_of_words: 3 # Min number of words for the left context. - max_number_of_words: 10 # Max number of words for the right context. - logprob_diff_threshold: 1.2 # The threshold for the difference between two log probability values from two hypotheses. 
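Note (illustrative, not part of the diff): the .env changes above replace the per-domain NeMo YAML files with three comma-separated diarization settings. Below is a minimal sketch, assuming the values are read straight from the environment, of how they could be turned into the float lists that the multiscale segmentation code in split_diarization.ipynb (added next) works with. The helper name parse_float_list and the direct os.getenv access are assumptions for illustration, not the project's actual settings code.

import os
from typing import List

def parse_float_list(env_name: str, default: str) -> List[float]:
    """Read a comma-separated environment variable and return it as a list of floats."""
    raw = os.getenv(env_name) or default
    return [float(value) for value in raw.split(",") if value.strip()]

# Defaults mirror the values documented in the .env diff above.
window_lengths = parse_float_list("WINDOW_LENGTHS", "1.5,1.25,1.0,0.75,0.5")
shift_lengths = parse_float_list("SHIFT_LENGTHS", "0.75,0.625,0.5,0.375,0.25")
multiscale_weights = parse_float_list("MULTISCALE_WEIGHTS", "1.0,1.0,1.0,1.0,1.0")

# Each diarization scale pairs one window length with one shift length and one weight,
# so the three lists must have the same number of entries (five scales with the defaults).
assert len(window_lengths) == len(shift_lengths) == len(multiscale_weights)

With the defaults this yields the same five scales passed as speaker_params in the notebook below.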
diff --git a/split_diarization.ipynb b/split_diarization.ipynb new file mode 100644 index 0000000..b1e1d34 --- /dev/null +++ b/split_diarization.ipynb @@ -0,0 +1,1832 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Base process" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:06:34 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.\n", + "[NeMo I 2023-07-25 09:06:34 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo\n", + "[NeMo I 2023-07-25 09:06:34 common:913] Instantiating model from pre-trained checkpoint\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-25 09:06:35 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json\n", + " sample_rate: 16000\n", + " labels: null\n", + " batch_size: 64\n", + " shuffle: true\n", + " is_tarred: false\n", + " tarred_audio_filepaths: null\n", + " tarred_shard_strategy: scatter\n", + " augmentor:\n", + " noise:\n", + " manifest_path: /manifests/noise/rir_noise_manifest.json\n", + " prob: 0.5\n", + " min_snr_db: 0\n", + " max_snr_db: 15\n", + " speed:\n", + " prob: 0.5\n", + " sr: 16000\n", + " resample_type: kaiser_fast\n", + " min_speed_rate: 0.95\n", + " max_speed_rate: 1.05\n", + " num_workers: 15\n", + " pin_memory: true\n", + " \n", + "[NeMo W 2023-07-25 09:06:35 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/dev.json\n", + " sample_rate: 16000\n", + " labels: null\n", + " batch_size: 128\n", + " shuffle: false\n", + " num_workers: 15\n", + " pin_memory: true\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:06:35 features:291] PADDING: 16\n", + "[NeMo I 2023-07-25 09:06:35 save_restore_connector:249] Model EncDecSpeakerLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.\n", + "[NeMo I 2023-07-25 09:06:35 clustering_diarizer:127] Loading pretrained vad_multilingual_marblenet model from NGC\n", + "[NeMo I 2023-07-25 09:06:35 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.\n", + "[NeMo I 2023-07-25 09:06:35 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo\n", + "[NeMo I 2023-07-25 09:06:35 common:913] Instantiating model from pre-trained checkpoint\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-25 09:06:35 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: /manifests/ami_train_0.63.json,/manifests/freesound_background_train.json,/manifests/freesound_laughter_train.json,/manifests/fisher_2004_background.json,/manifests/fisher_2004_speech_sampled.json,/manifests/google_train_manifest.json,/manifests/icsi_all_0.63.json,/manifests/musan_freesound_train.json,/manifests/musan_music_train.json,/manifests/musan_soundbible_train.json,/manifests/mandarin_train_sample.json,/manifests/german_train_sample.json,/manifests/spanish_train_sample.json,/manifests/french_train_sample.json,/manifests/russian_train_sample.json\n", + " sample_rate: 16000\n", + " labels:\n", + " - background\n", + " - speech\n", + " batch_size: 256\n", + " shuffle: true\n", + " is_tarred: false\n", + " tarred_audio_filepaths: null\n", + " tarred_shard_strategy: scatter\n", + " augmentor:\n", + " shift:\n", + " prob: 0.5\n", + " min_shift_ms: -10.0\n", + " max_shift_ms: 10.0\n", + " white_noise:\n", + " prob: 0.5\n", + " min_level: -90\n", + " max_level: -46\n", + " norm: true\n", + " noise:\n", + " prob: 0.5\n", + " manifest_path: /manifests/noise_0_1_musan_fs.json\n", + " min_snr_db: 0\n", + " max_snr_db: 30\n", + " max_gain_db: 300.0\n", + " norm: true\n", + " gain:\n", + " prob: 0.5\n", + " min_gain_dbfs: -10.0\n", + " max_gain_dbfs: 10.0\n", + " norm: true\n", + " num_workers: 16\n", + " pin_memory: true\n", + " \n", + "[NeMo W 2023-07-25 09:06:35 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath: /manifests/ami_dev_0.63.json,/manifests/freesound_background_dev.json,/manifests/freesound_laughter_dev.json,/manifests/ch120_moved_0.63.json,/manifests/fisher_2005_500_speech_sampled.json,/manifests/google_dev_manifest.json,/manifests/musan_music_dev.json,/manifests/mandarin_dev.json,/manifests/german_dev.json,/manifests/spanish_dev.json,/manifests/french_dev.json,/manifests/russian_dev.json\n", + " sample_rate: 16000\n", + " labels:\n", + " - background\n", + " - speech\n", + " batch_size: 256\n", + " shuffle: false\n", + " val_loss_idx: 0\n", + " num_workers: 16\n", + " pin_memory: true\n", + " \n", + "[NeMo W 2023-07-25 09:06:35 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath: null\n", + " sample_rate: 16000\n", + " labels:\n", + " - background\n", + " - speech\n", + " batch_size: 128\n", + " shuffle: false\n", + " test_loss_idx: 0\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:06:35 features:291] PADDING: 16\n", + "[NeMo I 2023-07-25 09:06:35 save_restore_connector:249] Model EncDecClassificationModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/vad_multilingual_marblenet/670f425c7f186060b7a7268ba6dfacb2/vad_multilingual_marblenet.nemo.\n" + ] + } + ], + "source": [ + "import json\n", + "\n", + "from nemo.collections.asr.models.label_models import EncDecSpeakerLabelModel\n", + "from nemo.collections.asr.models.msdd_models import ClusteringDiarizer\n", + "\n", + "from omegaconf import OmegaConf\n", + "\n", + "\n", + "# Load yaml file\n", + "with open(\"./config/nemo/diar_infer_telephonic.yaml\") as f:\n", + " cfg = OmegaConf.load(f)\n", + "\n", + "meta = {\n", + " \"audio_filepath\": \"mono_file.wav\",\n", + " \"offset\": 0,\n", + " \"duration\": None,\n", + " \"label\": \"infer\",\n", + " \"text\": \"-\",\n", + " \"rttm_filepath\": None,\n", + " \"uem_filepath\": None,\n", + "}\n", + "\n", + "manifest_path = \"infer_manifest.json\"\n", + "with open(\"infer_manifest.json\", \"w\") as fp:\n", + " json.dump(meta, fp)\n", + " fp.write(\"\\n\")\n", + "\n", + "cfg.diarizer.manifest_filepath = str(manifest_path)\n", + "cfg.diarizer.out_dir = \"infer_out_dir\"\n", + "\n", + "speaker_model = EncDecSpeakerLabelModel.from_pretrained(\n", + " model_name=\"titanet_large\", map_location=None\n", + ")\n", + "speaker_params = {\n", + " \"window_length_in_sec\": [1.5, 1.25, 1.0, 0.75, 0.5],\n", + " \"shift_length_in_sec\": [0.75, 0.625, 0.5, 0.375, 0.25],\n", + " \"multiscale_weights\": [1, 1, 1, 1, 1],\n", + " \"save_embeddings\": True,\n", + "}\n", + "cluster_params = {\n", + " \"oracle_num_speakers\": False,\n", + " \"max_num_speakers\": 8,\n", + " \"enhanced_count_thres\": 80,\n", + " \"max_rp_threshold\": 0.25,\n", + " \"sparse_search_volume\": 30,\n", + " \"maj_vote_spk_count\": False,\n", + "}\n", + "\n", + "clus_diar_model = ClusteringDiarizer(cfg=cfg, speaker_model=speaker_model)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import librosa\n", + "import soundfile as sf\n", + "\n", + "filepath = \"./mono_file.mp3\"\n", + "waveform, sample_rate = librosa.load(filepath, sr=None)\n", + "sf.write(\"./mono_file.wav\", waveform, sample_rate, \"PCM_16\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + 
"metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-25 09:37:44 clustering_diarizer:411] Deleting previous clustering diarizer outputs.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:44 speaker_utils:93] Number of files to diarize: 1\n", + "[NeMo I 2023-07-25 09:37:44 clustering_diarizer:309] Split long audio file to avoid CUDA memory issue\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "splitting manifest: 100%|██████████| 1/1 [00:00<00:00, 11.75it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:44 vad_utils:101] The prepared manifest file exists. Overwriting!\n", + "[NeMo I 2023-07-25 09:37:44 classification_models:268] Perform streaming frame-level VAD\n", + "[NeMo I 2023-07-25 09:37:44 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:44 collections:299] Dataset loaded with 3 items, total duration of 0.04 hours.\n", + "[NeMo I 2023-07-25 09:37:44 collections:301] # 3 files loaded accounting to # 1 labels\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:46 clustering_diarizer:250] Generating predictions with overlapping input segments\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " " + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:47 clustering_diarizer:262] Converting frame level prediction to speech/no-speech segment in start and end times format.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "creating speech segments: 100%|██████████| 1/1 [00:00<00:00, 3.86it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:47 clustering_diarizer:287] Subsegmentation for embedding extraction: scale0, infer_out_dir/speaker_outputs/subsegments_scale0.json\n", + "[NeMo I 2023-07-25 09:37:47 clustering_diarizer:343] Extracting embeddings for Diarization\n", + "[NeMo I 2023-07-25 09:37:47 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:47 collections:299] Dataset loaded with 104 items, total duration of 0.04 hours.\n", + "[NeMo I 2023-07-25 09:37:47 collections:301] # 104 files loaded accounting to # 1 labels\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:287] Subsegmentation for embedding extraction: scale1, infer_out_dir/speaker_outputs/subsegments_scale1.json\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:343] Extracting embeddings for Diarization\n", + "[NeMo I 2023-07-25 09:37:48 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:299] Dataset loaded with 132 items, total duration of 0.04 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:301] # 132 files loaded accounting to # 1 labels\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings\n", + "[NeMo I 2023-07-25 
09:37:48 clustering_diarizer:287] Subsegmentation for embedding extraction: scale2, infer_out_dir/speaker_outputs/subsegments_scale2.json\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:343] Extracting embeddings for Diarization\n", + "[NeMo I 2023-07-25 09:37:48 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:299] Dataset loaded with 166 items, total duration of 0.04 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:301] # 166 files loaded accounting to # 1 labels\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:287] Subsegmentation for embedding extraction: scale3, infer_out_dir/speaker_outputs/subsegments_scale3.json\n", + "[NeMo I 2023-07-25 09:37:48 clustering_diarizer:343] Extracting embeddings for Diarization\n", + "[NeMo I 2023-07-25 09:37:48 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:299] Dataset loaded with 222 items, total duration of 0.04 hours.\n", + "[NeMo I 2023-07-25 09:37:48 collections:301] # 222 files loaded accounting to # 1 labels\n", + "[NeMo I 2023-07-25 09:37:49 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings\n", + "[NeMo I 2023-07-25 09:37:49 clustering_diarizer:287] Subsegmentation for embedding extraction: scale4, infer_out_dir/speaker_outputs/subsegments_scale4.json\n", + "[NeMo I 2023-07-25 09:37:49 clustering_diarizer:343] Extracting embeddings for Diarization\n", + "[NeMo I 2023-07-25 09:37:49 collections:298] Filtered duration for loading collection is 0.00 hours.\n", + "[NeMo I 2023-07-25 09:37:49 collections:299] Dataset loaded with 343 items, total duration of 0.05 hours.\n", + "[NeMo I 2023-07-25 09:37:49 collections:301] # 343 files loaded accounting to # 1 labels\n", + "[NeMo I 2023-07-25 09:37:49 clustering_diarizer:389] Saved embedding files to infer_out_dir/speaker_outputs/embeddings\n", + "[NeMo I 2023-07-25 09:37:50 clustering_diarizer:464] Outputs are saved in /home/chainyo/wordcab-transcribe/infer_out_dir directory\n" + ] + } + ], + "source": [ + "clus_diar_model.diarize()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## VAD" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import List, Optional, Tuple, Union\n", + "\n", + "import torch\n", + "import torchaudio\n", + "from faster_whisper.vad import VadOptions, get_speech_timestamps\n", + "\n", + "\n", + "class VadService:\n", + " \"\"\"VAD Service for audio files.\"\"\"\n", + "\n", + " def __init__(self) -> None:\n", + " \"\"\"Initialize the VAD Service.\"\"\"\n", + " self.sample_rate = 16000\n", + " self.options = VadOptions(\n", + " threshold=0.5,\n", + " min_speech_duration_ms=250,\n", + " max_speech_duration_s=30,\n", + " min_silence_duration_ms=100,\n", + " window_size_samples=512,\n", + " speech_pad_ms=30,\n", + " )\n", + "\n", + " def __call__(\n", + " self, waveform: torch.Tensor, group_timestamps: Optional[bool] = True\n", + " ) -> Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]:\n", + " \"\"\"\n", + " Use the VAD model to get the speech timestamps. Dual channel pipeline.\n", + "\n", + " Args:\n", + " waveform (torch.Tensor): Audio tensor.\n", + " group_timestamps (Optional[bool], optional): Group timestamps. 
Defaults to True.\n", + "\n", + " Returns:\n", + " Tuple[Union[List[dict], List[List[dict]]], torch.Tensor]: Speech timestamps and audio tensor.\n", + " \"\"\"\n", + " if waveform.size(0) == 1:\n", + " waveform = waveform.squeeze(0)\n", + "\n", + " speech_timestamps = get_speech_timestamps(\n", + " audio=waveform, vad_options=self.options\n", + " )\n", + "\n", + " _speech_timestamps_list = [\n", + " {\"start\": ts[\"start\"], \"end\": ts[\"end\"]} for ts in speech_timestamps\n", + " ]\n", + "\n", + " if group_timestamps:\n", + " speech_timestamps_list = self.group_timestamps(_speech_timestamps_list)\n", + " else:\n", + " speech_timestamps_list = _speech_timestamps_list\n", + "\n", + " return speech_timestamps_list, waveform\n", + "\n", + " def group_timestamps(\n", + " self, timestamps: List[dict], threshold: Optional[float] = 3.0\n", + " ) -> List[List[dict]]:\n", + " \"\"\"\n", + " Group timestamps based on a threshold.\n", + "\n", + " Args:\n", + " timestamps (List[dict]): List of timestamps.\n", + " threshold (float, optional): Threshold to use for grouping. Defaults to 3.0.\n", + "\n", + " Returns:\n", + " List[List[dict]]: List of grouped timestamps.\n", + " \"\"\"\n", + " grouped_segments = [[]]\n", + "\n", + " for i in range(len(timestamps)):\n", + " if (\n", + " i > 0\n", + " and (timestamps[i][\"start\"] - timestamps[i - 1][\"end\"]) > threshold\n", + " ):\n", + " grouped_segments.append([])\n", + "\n", + " grouped_segments[-1].append(timestamps[i])\n", + "\n", + " return grouped_segments\n", + "\n", + " def save_audio(self, filepath: str, audio: torch.Tensor) -> None:\n", + " \"\"\"\n", + " Save audio tensor to file.\n", + "\n", + " Args:\n", + " filepath (str): Path to save the audio file.\n", + " audio (torch.Tensor): Audio tensor.\n", + " \"\"\"\n", + " torchaudio.save(\n", + " filepath, audio.unsqueeze(0), self.sample_rate, bits_per_sample=16\n", + " )\n", + "\n", + "def read_audio(filepath: str, sample_rate: int = 16000) -> Tuple[torch.Tensor, float]:\n", + " \"\"\"\n", + " Read an audio file and return the audio tensor.\n", + "\n", + " Args:\n", + " filepath (str): Path to the audio file.\n", + " sample_rate (int): The sample rate of the audio file. 
Defaults to 16000.\n", + "\n", + " Returns:\n", + " Tuple[torch.Tensor, float]: The audio tensor and the audio duration.\n", + " \"\"\"\n", + " wav, sr = torchaudio.load(filepath)\n", + "\n", + " if wav.size(0) > 1:\n", + " wav = wav.mean(dim=0, keepdim=True)\n", + "\n", + " if sr != sample_rate:\n", + " transform = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)\n", + " wav = transform(wav)\n", + " sr = sample_rate\n", + "\n", + " audio_duration = float(wav.shape[1]) / sample_rate\n", + "\n", + " return wav.squeeze(0), audio_duration\n", + "\n", + "def sr2s(v: int) -> float:\n", + " \"\"\"\n", + " Convert milliseconds to seconds.\n", + "\n", + " Args:\n", + " v (int): Value in milliseconds.\n", + "\n", + " Returns:\n", + " float: Value in seconds.\n", + " \"\"\"\n", + " return v / 16000" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "42\n", + "Start: 12.514, End: 12.99\n", + "Start: 13.25, End: 14.366\n", + "Start: 15.042, End: 16.862\n", + "Start: 17.762, End: 18.814\n", + "Start: 19.426, End: 20.766\n", + "Start: 21.666, End: 24.83\n", + "Start: 26.178, End: 29.886\n", + "Start: 30.786, End: 33.022\n", + "Start: 34.146, End: 37.214\n", + "Start: 38.338, End: 40.318\n", + "Start: 41.218, End: 42.782\n", + "Start: 43.682, End: 44.318\n", + "Start: 45.73, End: 46.494\n", + "Start: 47.778, End: 50.366\n", + "Start: 51.106, End: 52.926\n", + "Start: 53.954, End: 55.582\n", + "Start: 55.682, End: 56.926\n", + "Start: 57.954, End: 60.286\n", + "Start: 61.154, End: 64.254\n", + "Start: 65.026, End: 67.614\n", + "Start: 68.418, End: 68.99\n", + "Start: 69.922, End: 71.55\n", + "Start: 72.578, End: 75.838\n", + "Start: 76.61, End: 77.918\n", + "Start: 78.562, End: 79.454\n", + "Start: 79.746, End: 81.086\n", + "Start: 82.05, End: 83.902\n", + "Start: 84.738, End: 86.462\n", + "Start: 87.586, End: 90.782\n", + "Start: 91.746, End: 96.542\n", + "Start: 97.73, End: 98.27\n", + "Start: 99.586, End: 100.03\n", + "Start: 100.162, End: 100.862\n", + "Start: 101.794, End: 103.454\n", + "Start: 104.898, End: 107.486\n", + "Start: 108.226, End: 109.854\n", + "Start: 114.274, End: 114.75\n", + "Start: 114.914, End: 116.19\n", + "Start: 117.154, End: 119.358\n", + "Start: 120.194, End: 120.67\n", + "Start: 120.834, End: 121.886\n", + "Start: 122.978, End: 127.102\n" + ] + } + ], + "source": [ + "waveform, _ = read_audio(\"./mono_file.wav\")\n", + "\n", + "vad_service = VadService()\n", + "\n", + "speech_ts, _ = vad_service(waveform, True)\n", + "for ts in speech_ts:\n", + " _ts = ts[0]\n", + " print(\n", + " f\"Start: {sr2s(_ts['start'])}, End: {sr2s(_ts['end'])}\"\n", + " )" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Segmentation" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "\n", + "def get_subsegments(segment_start: float, segment_end: float, window: float, shift: float) -> List[List[float]]:\n", + " \"\"\"\n", + " Return a list of subsegments based on the segment start and end time and the window and shift length.\n", + "\n", + " Args:\n", + " segment_start (float): Segment start time.\n", + " segment_end (float): Segment end time.\n", + " window (float): Window length.\n", + " shift (float): Shift length.\n", + "\n", + " Returns:\n", + " List[List[float]]: List of subsegments with start time and duration.\n", + " \"\"\"\n", + " start = 
segment_start\n", + " duration = segment_end - segment_start\n", + " base = math.ceil((duration - window) / shift)\n", + " \n", + " subsegments: List[List[float]] = []\n", + " slices = 1 if base < 0 else base + 1\n", + " for slice_id in range(slices):\n", + " end = start + window\n", + "\n", + " if end > segment_end:\n", + " end = segment_end\n", + "\n", + " subsegments.append([start, end - start])\n", + "\n", + " start = segment_start + (slice_id + 1) * shift\n", + "\n", + " return subsegments" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "def _run_segmentation(\n", + " vad_outputs: List[dict],\n", + " window: float,\n", + " shift: float,\n", + " min_subsegment_duration: float = 0.05,\n", + ") -> List[dict]:\n", + " \"\"\"\"\"\"\n", + " scale_segment = []\n", + " for segment in vad_outputs:\n", + " segment_start, segment_end = sr2s(segment[\"start\"]), sr2s(segment[\"end\"])\n", + " subsegments = get_subsegments(segment_start, segment_end, window, shift)\n", + "\n", + " for subsegment in subsegments:\n", + " start, duration = subsegment\n", + " if duration > min_subsegment_duration:\n", + " scale_segment.append({\"offset\": start, \"duration\": duration})\n", + "\n", + " return scale_segment" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-31 07:04:19 experimental:27] Module is experimental, not ready for production and is not fully supported. Use at your own risk.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-31 07:04:19 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.\n", + "[NeMo I 2023-07-31 07:04:19 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo\n", + "[NeMo I 2023-07-31 07:04:19 common:913] Instantiating model from pre-trained checkpoint\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-31 07:04:22 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/train.json\n", + " sample_rate: 16000\n", + " labels: null\n", + " batch_size: 64\n", + " shuffle: true\n", + " is_tarred: false\n", + " tarred_audio_filepaths: null\n", + " tarred_shard_strategy: scatter\n", + " augmentor:\n", + " noise:\n", + " manifest_path: /manifests/noise/rir_noise_manifest.json\n", + " prob: 0.5\n", + " min_snr_db: 0\n", + " max_snr_db: 15\n", + " speed:\n", + " prob: 0.5\n", + " sr: 16000\n", + " resample_type: kaiser_fast\n", + " min_speed_rate: 0.95\n", + " max_speed_rate: 1.05\n", + " num_workers: 15\n", + " pin_memory: true\n", + " \n", + "[NeMo W 2023-07-31 07:04:22 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath: /manifests/combined_fisher_swbd_voxceleb12_librispeech/dev.json\n", + " sample_rate: 16000\n", + " labels: null\n", + " batch_size: 128\n", + " shuffle: false\n", + " num_workers: 15\n", + " pin_memory: true\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-31 07:04:22 features:291] PADDING: 16\n", + "[NeMo I 2023-07-31 07:04:24 save_restore_connector:249] Model EncDecSpeakerLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/titanet-l/11ba0924fdf87c049e339adbf6899d48/titanet-l.nemo.\n" + ] + } + ], + "source": [ + "from nemo.collections.asr.models import EncDecSpeakerLabelModel\n", + "\n", + "from torch.cuda.amp import autocast\n", + "from torch.utils.data import Dataset\n", + "\n", + "speaker_model = EncDecSpeakerLabelModel.from_pretrained(\n", + " model_name=\"titanet_large\", map_location=None\n", + ")\n", + "\n", + "\n", + "class AudioSegmentDataset(Dataset):\n", + " def __init__(self, waveform: torch.Tensor, segments: List[dict], sample_rate=16000) -> None:\n", + " self.waveform = waveform\n", + " self.segments = segments\n", + " self.sample_rate = sample_rate\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.segments)\n", + "\n", + " def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " segment_info = self.segments[idx]\n", + " offset_samples = int(segment_info[\"offset\"] * self.sample_rate)\n", + " duration_samples = int(segment_info[\"duration\"] * self.sample_rate)\n", + "\n", + " segment = self.waveform[offset_samples:offset_samples + duration_samples]\n", + "\n", + " return segment, torch.tensor(segment.shape[0]).long()\n", + "\n", + "\n", + "def collate_fn(batch: List[Tuple[torch.Tensor, torch.Tensor]]):\n", + " \"\"\"\"\"\"\n", + " _, audio_lengths = zip(*batch)\n", + "\n", + " has_audio = audio_lengths[0] is not None\n", + " fixed_length = int(max(audio_lengths))\n", + "\n", + " audio_signal, new_audio_lengths = [], []\n", + " for sig, sig_len in batch:\n", + " if has_audio:\n", + " sig_len = sig_len.item()\n", + " chunck_len = sig_len - fixed_length\n", + "\n", + " if chunck_len < 0:\n", + " repeat = fixed_length // sig_len\n", + " rem = fixed_length % sig_len\n", + " sub = sig[-rem:] if rem > 0 else torch.tensor([])\n", + " rep_sig = torch.cat(repeat * [sig])\n", + " sig = torch.cat((rep_sig, sub))\n", + " new_audio_lengths.append(torch.tensor(fixed_length))\n", + "\n", + " audio_signal.append(sig)\n", + "\n", + " if has_audio:\n", + " audio_signal = torch.stack(audio_signal)\n", + " audio_lengths = torch.stack(new_audio_lengths)\n", + " else:\n", + " audio_signal, audio_lengths = None, None\n", + "\n", + " return audio_signal, audio_lengths\n", + "\n", + "\n", + "def _extract_embeddings(waveform: torch.Tensor, scale_segments: List[dict]):\n", + " \"\"\"\n", + " This method extracts speaker embeddings from segments passed through manifest_file\n", + " Optionally you may save the intermediate speaker embeddings for debugging or any use. 
\n", + " \"\"\"\n", + " all_embs = torch.empty([0])\n", + "\n", + " dataset = AudioSegmentDataset(waveform, scale_segments)\n", + " dataloader = torch.utils.data.DataLoader(\n", + " dataset, batch_size=64, shuffle=False, collate_fn=collate_fn\n", + " )\n", + "\n", + " for batch in dataloader:\n", + " _batch = [x.to(speaker_model.device) for x in batch]\n", + " audio_signal, audio_signal_len = _batch\n", + "\n", + " with autocast():\n", + " _, embeddings = speaker_model.forward(input_signal=audio_signal, input_signal_length=audio_signal_len)\n", + " embeddings = embeddings.view(-1, embeddings.shape[-1])\n", + " all_embs = torch.cat((all_embs, embeddings.cpu().detach()), dim=0)\n", + " del _batch, audio_signal, audio_signal_len, embeddings\n", + "\n", + " embeddings, time_stamps = [], []\n", + " for i, segment in enumerate(scale_segments):\n", + " if i == 0:\n", + " embeddings = all_embs[i].view(1, -1)\n", + " else:\n", + " embeddings = torch.cat((embeddings, all_embs[i].view(1, -1)))\n", + "\n", + " time_stamps.append([segment['offset'], segment['duration']])\n", + "\n", + " return embeddings, time_stamps" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clustering" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering\n", + "\n", + "\n", + "def get_contiguous_stamps(stamps: list):\n", + " \"\"\"\n", + " Return contiguous time stamps\n", + " \"\"\"\n", + " contiguous_stamps = []\n", + " for i in range(len(stamps) - 1):\n", + " start, end, speaker = stamps[i]\n", + " next_start, next_end, next_speaker = stamps[i + 1]\n", + "\n", + " if end > next_start:\n", + " avg = (next_start + end) / 2.0\n", + " stamps[i + 1] = (avg, next_end, next_speaker)\n", + " contiguous_stamps.append((start, avg, speaker))\n", + " else:\n", + " contiguous_stamps.append((start, end, speaker))\n", + "\n", + " start, end, speaker = stamps[-1]\n", + " contiguous_stamps.append((start, end, speaker))\n", + "\n", + " return contiguous_stamps\n", + "\n", + "\n", + "def merge_stamps(stamps: list):\n", + " \"\"\"\n", + " Merge time stamps of the same speaker.\n", + " \"\"\"\n", + " overlap_stamps = []\n", + " for i in range(len(stamps) - 1):\n", + " start, end, speaker = stamps[i]\n", + " next_start, next_end, next_speaker = stamps[i + 1]\n", + "\n", + " if end == next_start and speaker == next_speaker:\n", + " stamps[i + 1] = (start, next_end, next_speaker)\n", + " else:\n", + " overlap_stamps.append((start, end, speaker))\n", + "\n", + " start, end, speaker = stamps[-1]\n", + " overlap_stamps.append((start, end, speaker))\n", + "\n", + " return overlap_stamps\n", + "\n", + "\n", + "def perform_clustering(embs_and_timestamps, clustering_params):\n", + " \"\"\"\n", + " Performs spectral clustering on embeddings with time stamps generated from VAD output.\n", + " \"\"\"\n", + " speaker_clustering = SpeakerClustering(cuda=True)\n", + "\n", + " base_scale_idx = embs_and_timestamps[\"multiscale_segment_counts\"].shape[0] - 1\n", + " cluster_labels = speaker_clustering.forward_infer(\n", + " embeddings_in_scales=embs_and_timestamps[\"embeddings\"],\n", + " timestamps_in_scales=embs_and_timestamps[\"timestamps\"],\n", + " multiscale_segment_counts=embs_and_timestamps[\"multiscale_segment_counts\"],\n", + " multiscale_weights=embs_and_timestamps[\"multiscale_weights\"],\n", + " oracle_num_speakers=-1,\n", + " 
max_num_speakers=int(clustering_params[\"max_num_speakers\"]),\n", + " max_rp_threshold=float(clustering_params[\"max_rp_threshold\"]),\n", + " sparse_search_volume=int(clustering_params[\"sparse_search_volume\"]),\n", + " )\n", + "\n", + " del embs_and_timestamps\n", + " torch.cuda.empty_cache()\n", + "\n", + " timestamps = speaker_clustering.timestamps_in_scales[base_scale_idx]\n", + " cluster_labels = cluster_labels.cpu().numpy()\n", + " if len(cluster_labels) != timestamps.shape[0]:\n", + " raise ValueError(\"Mismatch of length between cluster_labels and timestamps.\")\n", + "\n", + " clustering_labels = []\n", + " for idx, label in enumerate(cluster_labels):\n", + " stt, end = timestamps[idx]\n", + " clustering_labels.append((float(stt), float(stt + end), int(label)))\n", + "\n", + " return clustering_labels" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Mapping between embeddings and timestamps" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from statistics import mode\n", + "from typing import List\n", + "\n", + "import numpy as np\n", + "import torch\n", + "\n", + "\n", + "def get_argmin_mat(timestamps_in_scales: List[torch.Tensor]) -> List[torch.Tensor]:\n", + " \"\"\"\n", + " Calculate the mapping between the base scale and other scales. A segment from a longer scale is\n", + " repeatedly mapped to a segment from a shorter scale or the base scale.\n", + "\n", + " Args:\n", + " timestamps_in_scales (list):\n", + " List containing timestamp tensors for each scale.\n", + " Each tensor has dimensions of (Number of base segments) x 2.\n", + "\n", + " Returns:\n", + " session_scale_mapping_list (list):\n", + " List containing argmin arrays indexed by scale index.\n", + " \"\"\"\n", + " scale_list = list(range(len(timestamps_in_scales)))\n", + " segment_anchor_list = [torch.mean(timestamps_in_scales[scale_idx], dim=1) for scale_idx in scale_list]\n", + "\n", + " base_scale_idx = max(scale_list)\n", + " base_scale_anchor = segment_anchor_list[base_scale_idx]\n", + " base_scale_anchor = base_scale_anchor.view(-1, 1)\n", + "\n", + " session_scale_mapping_list = []\n", + " for scale_idx in scale_list:\n", + " curr_scale_anchor = segment_anchor_list[scale_idx].view(1, -1)\n", + " distance = torch.abs(curr_scale_anchor - base_scale_anchor)\n", + " argmin_mat = torch.argmin(distance, dim=1)\n", + " session_scale_mapping_list.append(argmin_mat)\n", + "\n", + " return session_scale_mapping_list\n", + "\n", + "\n", + "def assign_labels_to_longer_segs(clustering_labels: list, session_scale_mapping_list: list, scale_n: int):\n", + " \"\"\"\n", + " In multi-scale speaker diarization system, clustering result is solely based on the base-scale (the shortest scale).\n", + " To calculate cluster-average speaker embeddings for each scale that are longer than the base-scale, this function assigns\n", + " clustering results for the base-scale to the longer scales by measuring the distance between subsegment timestamps in the\n", + " base-scale and non-base-scales.\n", + "\n", + " Args:\n", + " base_clus_label_dict (dict):\n", + " Dictionary containing clustering results for base-scale segments. Indexed by `uniq_id` string.\n", + " session_scale_mapping_dict (dict):\n", + " Dictionary containing multiscale mapping information for each session. 
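get_argmin_mat is what ties the scales together: each segment is anchored by the mean of its [offset, duration] row, and every base-scale segment is assigned the index of the closest segment in each longer scale. A self-contained sketch of that computation on invented numbers (one long scale with 2 segments, a base scale with 6):

import torch

long_scale = torch.tensor([[0.0, 2.0], [1.5, 2.0]])                       # anchors 1.0 and 1.75
base_scale = torch.tensor(
    [[0.0, 0.5], [0.5, 0.5], [1.0, 0.5], [1.5, 0.5], [2.0, 0.5], [2.5, 0.5]]
)                                                                          # anchors 0.25 ... 1.5

long_anchor = long_scale.mean(dim=1).view(1, -1)   # shape (1, 2)
base_anchor = base_scale.mean(dim=1).view(-1, 1)   # shape (6, 1)
mapping = torch.argmin(torch.abs(long_anchor - base_anchor), dim=1)

print(mapping)  # tensor([0, 0, 0, 0, 0, 1]) -> only the last base segment maps to the second long segment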
Indexed by `uniq_id` string.\n", + "\n", + " Returns:\n", + " all_scale_clus_label_dict (dict):\n", + " Dictionary containing clustering labels of all scales. Indexed by scale_index in integer format.\n", + "\n", + " \"\"\"\n", + " base_scale_clus_label = np.array([x[-1] for x in clustering_labels])\n", + " \n", + " all_scale_clus_label_dict = {}\n", + " all_scale_clus_label_dict[scale_n - 1] = base_scale_clus_label\n", + "\n", + " for scale_index, scale_mapping_tensor in enumerate(session_scale_mapping_list[:-1]):\n", + " new_clus_label = []\n", + " max_index = max(scale_mapping_tensor)\n", + "\n", + " for seg_idx in range(max_index + 1):\n", + " if seg_idx in scale_mapping_tensor:\n", + " seg_clus_label = mode(base_scale_clus_label[scale_mapping_tensor == seg_idx])\n", + " else:\n", + " seg_clus_label = 0 if len(new_clus_label) == 0 else new_clus_label[-1]\n", + "\n", + " new_clus_label.append(seg_clus_label)\n", + "\n", + " all_scale_clus_label_dict[scale_index] = new_clus_label\n", + "\n", + " return all_scale_clus_label_dict\n", + "\n", + "\n", + "# Check https://github.com/NVIDIA/NeMo/blob/2cc09425aba3e9b3cfdba43a3188eaef58227055/nemo/collections/asr/models/msdd_models.py#L756\n", + "def get_cluster_avg_embs(\n", + " emb_scale_seq_dict: dict,\n", + " clustering_labels: list,\n", + " session_scale_mapping_list: list,\n", + " scale_n: int,\n", + " max_num_speakers: int,\n", + "):\n", + " \"\"\"\n", + " MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average embedding vector for each cluster (speaker)\n", + " and each scale.\n", + "\n", + " Args:\n", + " emb_scale_seq_dict (dict):\n", + " Dictionary containing embedding sequence for each scale. Keys are scale index in integer.\n", + " clus_labels (list):\n", + " Clustering results from clustering diarizer including all the sessions provided in input manifest files.\n", + " session_scale_mapping_dict (list):\n", + " List containing argmin arrays indexed by scale index.\n", + "\n", + " Returns:\n", + " emb_sess_avg_dict (dict):\n", + " Dictionary containing speaker mapping information and cluster-average speaker embedding vector.\n", + " Each session-level dictionary is indexed by scale index in integer.\n", + " output_clus_label_dict (dict):\n", + " Subegmentation timestamps in float type and Clustering result in integer type. 
Indexed by `uniq_id` keys.\n", + " \"\"\" \n", + " embeddings_session_average_dict = {}\n", + "\n", + " all_scale_clus_label_dict = assign_labels_to_longer_segs(\n", + " clustering_labels, session_scale_mapping_list, scale_n\n", + " )\n", + " \n", + " for scale_index, embeddings_tensor in emb_scale_seq_dict.items():\n", + " clustering_labels_list = all_scale_clus_label_dict[scale_index]\n", + " speaker_set = set(clustering_labels_list)\n", + "\n", + " clustering_labels_tensor = torch.Tensor(clustering_labels_list)\n", + " average_embeddings = torch.zeros(embeddings_tensor[0].shape[0], max_num_speakers)\n", + " for speaker_idx in speaker_set:\n", + " selected_embeddings = embeddings_tensor[clustering_labels_tensor == speaker_idx]\n", + " average_embeddings[:, speaker_idx] = torch.mean(selected_embeddings, dim=0)\n", + "\n", + " embeddings_session_average_dict[scale_index] = average_embeddings\n", + "\n", + " return embeddings_session_average_dict\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MSDD Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import combinations\n", + "from typing import Dict\n", + "\n", + "\n", + "class AudioMSDDDataset(Dataset):\n", + " def __init__(\n", + " self,\n", + " emb_sess_avg_dict: Dict[str, torch.Tensor],\n", + " emb_scale_seq_dict: Dict[str, torch.Tensor],\n", + " clustering_labels: Dict[str, torch.Tensor],\n", + " sess_scale_mapping_list: List[torch.Tensor],\n", + " scale_n: int,\n", + " ) -> None:\n", + " self.emb_dict = emb_sess_avg_dict\n", + " self.emb_seq = emb_scale_seq_dict\n", + " self.clus_label_list = clustering_labels\n", + " self.sess_scale_mapping = sess_scale_mapping_list\n", + " self.scale_n = scale_n\n", + "\n", + " self.clus_speaker_digits = sorted(list(set([x[-1] for x in self.clus_label_list])))\n", + " if len(self.clus_speaker_digits) <= 2:\n", + " self.speaker_combinations = [(0, 1)]\n", + " else:\n", + " self.speaker_combinations = [x for x in combinations(self.clus_speaker_digits, 2)]\n", + "\n", + " def __len__(self) -> int:\n", + " return 1\n", + "\n", + " def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " _avg_embs = torch.stack(\n", + " [\n", + " self.emb_dict[scale_index] \n", + " for scale_index in range(self.scale_n)\n", + " ]\n", + " ) # (scale_n, num_segments, max_num_speakers)\n", + "\n", + " selected_speakers = torch.tensor(self.speaker_combinations).flatten()\n", + " avg_embs = _avg_embs[:, :, selected_speakers]\n", + " \n", + "\n", + " if avg_embs.shape[2] > 2:\n", + " raise ValueError(\n", + " f\" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {2}\"\n", + " )\n", + "\n", + " feats = []\n", + " for scale_index in range(self.scale_n):\n", + " repeat_mat = self.sess_scale_mapping[scale_index]\n", + " feats.append(self.emb_seq[scale_index][repeat_mat, :])\n", + "\n", + " features = torch.stack(feats).permute(1, 0, 2)\n", + " features_length = features.shape[0]\n", + "\n", + " targets = torch.zeros(features_length, 2)\n", + "\n", + " return features, features_length, targets, avg_embs\n", + "\n", + "\n", + "def msdd_infer_collate_fn(batch):\n", + " \"\"\"\n", + " Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings.\n", + "\n", + " Args:\n", + " batch (tuple):\n", + " Batch tuple containing feats, feats_len, targets and ms_avg_embs.\n", + " 
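MSDD consumes, for every scale, one cluster-average reference embedding per speaker; get_cluster_avg_embs above packs them into a (embedding_dim, max_num_speakers) matrix per scale, leaving unused speaker columns at zero. A minimal sketch of that averaging with made-up numbers (4 segments, 3-dimensional embeddings, 2 speakers, padded to 4 speaker slots):

import torch

embeddings = torch.tensor([
    [1.0, 0.0, 0.0],
    [0.8, 0.2, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.8, 0.2],
])
labels = torch.tensor([0, 0, 1, 1])
max_num_speakers = 4

average_embeddings = torch.zeros(embeddings.shape[1], max_num_speakers)
for speaker_idx in set(labels.tolist()):
    average_embeddings[:, speaker_idx] = embeddings[labels == speaker_idx].mean(dim=0)

print(average_embeddings)
# column 0 = mean of the first two rows, column 1 = mean of the last two, columns 2 and 3 stay zero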
Returns:\n", + " feats (torch.tensor):\n", + " Collated speaker embedding with unified length.\n", + " feats_len (torch.tensor):\n", + " The actual length of each embedding sequence without zero padding.\n", + " targets (torch.tensor):\n", + " Groundtruth Speaker label for the given input embedding sequence.\n", + " ms_avg_embs (torch.tensor):\n", + " Cluster-average speaker embedding vectors.\n", + " \"\"\"\n", + "\n", + " packed_batch = list(zip(*batch))\n", + " _, feats_len, targets, _ = packed_batch\n", + " max_audio_len = max(feats_len)\n", + " max_target_len = max([x.shape[0] for x in targets])\n", + "\n", + " feats_list, flen_list, targets_list, ms_avg_embs_list = [], [], [], []\n", + " for feature, feat_len, target, ivector in batch:\n", + " flen_list.append(feat_len)\n", + " ms_avg_embs_list.append(ivector)\n", + "\n", + " if feat_len < max_audio_len:\n", + " feats_list.append(\n", + " torch.nn.functional.pad(feature, (0, 0, 0, 0, 0, max_audio_len - feat_len))\n", + " )\n", + " targets_list.append(\n", + " torch.nn.functional.pad(target, (0, 0, 0, max_target_len - target.shape[0]))\n", + " )\n", + " else:\n", + " targets_list.append(target.clone().detach())\n", + " feats_list.append(feature.clone().detach())\n", + "\n", + " return (\n", + " torch.stack(feats_list), # Features\n", + " torch.tensor(flen_list), # Features length\n", + " torch.stack(targets_list), # Targets\n", + " torch.stack(ms_avg_embs_list), # Cluster-average embeddings\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-31 08:26:51 cloud:58] Found existing object /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.\n", + "[NeMo I 2023-07-31 08:26:51 cloud:64] Re-using file from: /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo\n", + "[NeMo I 2023-07-31 08:26:51 common:913] Instantiating model from pre-trained checkpoint\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[NeMo W 2023-07-31 08:26:52 modelPT:161] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.\n", + " Train config : \n", + " manifest_filepath: null\n", + " emb_dir: null\n", + " sample_rate: 16000\n", + " num_spks: 2\n", + " soft_label_thres: 0.5\n", + " labels: null\n", + " batch_size: 15\n", + " emb_batch_size: 0\n", + " shuffle: true\n", + " \n", + "[NeMo W 2023-07-31 08:26:52 modelPT:168] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
\n", + " Validation config : \n", + " manifest_filepath: null\n", + " emb_dir: null\n", + " sample_rate: 16000\n", + " num_spks: 2\n", + " soft_label_thres: 0.5\n", + " labels: null\n", + " batch_size: 15\n", + " emb_batch_size: 0\n", + " shuffle: false\n", + " \n", + "[NeMo W 2023-07-31 08:26:52 modelPT:174] Please call the ModelPT.setup_test_data() or ModelPT.setup_multiple_test_data() method and provide a valid configuration file to setup the test data loader(s).\n", + " Test config : \n", + " manifest_filepath: null\n", + " emb_dir: null\n", + " sample_rate: 16000\n", + " num_spks: 2\n", + " soft_label_thres: 0.5\n", + " labels: null\n", + " batch_size: 15\n", + " emb_batch_size: 0\n", + " shuffle: false\n", + " seq_eval_mode: false\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[NeMo I 2023-07-31 08:26:52 features:291] PADDING: 16\n", + "[NeMo I 2023-07-31 08:26:52 features:291] PADDING: 16\n", + "[NeMo I 2023-07-31 08:26:53 save_restore_connector:249] Model EncDecDiarLabelModel was successfully restored from /home/chainyo/.cache/torch/NeMo/NeMo_1.19.1/diar_msdd_telephonic/3c3697a0a46f945574fa407149975a13/diar_msdd_telephonic.nemo.\n" + ] + }, + { + "data": { + "text/plain": [ + "EncDecDiarLabelModel(\n", + " (preprocessor): AudioToMelSpectrogramPreprocessor(\n", + " (featurizer): FilterbankFeatures()\n", + " )\n", + " (msdd): MSDD_module(\n", + " (softmax): Softmax(dim=2)\n", + " (cos_dist): CosineSimilarity()\n", + " (lstm): LSTM(256, 256, num_layers=2, batch_first=True, dropout=0.5, bidirectional=True)\n", + " (conv): ModuleList(\n", + " (0): ConvLayer(\n", + " (cnn): Sequential(\n", + " (0): Conv2d(1, 16, kernel_size=(15, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (1): ConvLayer(\n", + " (cnn): Sequential(\n", + " (0): Conv2d(1, 16, kernel_size=(16, 1), stride=(1, 1))\n", + " (1): ReLU()\n", + " (2): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " )\n", + " (conv_bn): ModuleList(\n", + " (0-1): 2 x BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)\n", + " )\n", + " (conv_to_linear): Linear(in_features=3072, out_features=256, bias=True)\n", + " (linear_to_weights): Linear(in_features=256, out_features=5, bias=True)\n", + " (hidden_to_spks): Linear(in_features=512, out_features=2, bias=True)\n", + " (dist_to_emb): Linear(in_features=10, out_features=256, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " (_speaker_model): EncDecSpeakerLabelModel(\n", + " (loss): AngularSoftmaxLoss()\n", + " (eval_loss): AngularSoftmaxLoss()\n", + " (_accuracy): TopKClassificationAccuracy()\n", + " (preprocessor): AudioToMelSpectrogramPreprocessor(\n", + " (featurizer): FilterbankFeatures()\n", + " )\n", + " (encoder): ConvASREncoder(\n", + " (encoder): Sequential(\n", + " (0): JasperBlock(\n", + " (mconv): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,), groups=80, bias=False)\n", + " )\n", + " (1): MaskedConv1d(\n", + " (conv): Conv1d(80, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (2): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): SqueezeExcite(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=1024, out_features=128, bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): 
Linear(in_features=128, out_features=1024, bias=False)\n", + " )\n", + " (gap): AdaptiveAvgPool1d(output_size=1)\n", + " )\n", + " )\n", + " (mout): Sequential(\n", + " (0): ReLU(inplace=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " (1): JasperBlock(\n", + " (mconv): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(7,), stride=(1,), padding=(3,), groups=1024, bias=False)\n", + " )\n", + " (1): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (2): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " (5): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(7,), stride=(1,), padding=(3,), groups=1024, bias=False)\n", + " )\n", + " (6): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (7): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout(p=0.1, inplace=False)\n", + " (10): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(7,), stride=(1,), padding=(3,), groups=1024, bias=False)\n", + " )\n", + " (11): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (12): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (13): SqueezeExcite(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=1024, out_features=128, bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=128, out_features=1024, bias=False)\n", + " )\n", + " (gap): AdaptiveAvgPool1d(output_size=1)\n", + " )\n", + " )\n", + " (res): ModuleList(\n", + " (0): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (1): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (mout): Sequential(\n", + " (0): ReLU(inplace=True)\n", + " (1): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): JasperBlock(\n", + " (mconv): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(11,), stride=(1,), padding=(5,), groups=1024, bias=False)\n", + " )\n", + " (1): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (2): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " (5): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(11,), stride=(1,), padding=(5,), groups=1024, bias=False)\n", + " )\n", + " (6): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (7): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout(p=0.1, inplace=False)\n", + " (10): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(11,), stride=(1,), padding=(5,), groups=1024, bias=False)\n", + " )\n", + " (11): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (12): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (13): SqueezeExcite(\n", + " (fc): Sequential(\n", + 
" (0): Linear(in_features=1024, out_features=128, bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=128, out_features=1024, bias=False)\n", + " )\n", + " (gap): AdaptiveAvgPool1d(output_size=1)\n", + " )\n", + " )\n", + " (res): ModuleList(\n", + " (0): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (1): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (mout): Sequential(\n", + " (0): ReLU(inplace=True)\n", + " (1): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): JasperBlock(\n", + " (mconv): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(15,), stride=(1,), padding=(7,), groups=1024, bias=False)\n", + " )\n", + " (1): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (2): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): ReLU(inplace=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " (5): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(15,), stride=(1,), padding=(7,), groups=1024, bias=False)\n", + " )\n", + " (6): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (7): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout(p=0.1, inplace=False)\n", + " (10): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(15,), stride=(1,), padding=(7,), groups=1024, bias=False)\n", + " )\n", + " (11): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (12): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (13): SqueezeExcite(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=1024, out_features=128, bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=128, out_features=1024, bias=False)\n", + " )\n", + " (gap): AdaptiveAvgPool1d(output_size=1)\n", + " )\n", + " )\n", + " (res): ModuleList(\n", + " (0): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (1): BatchNorm1d(1024, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (mout): Sequential(\n", + " (0): ReLU(inplace=True)\n", + " (1): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): JasperBlock(\n", + " (mconv): ModuleList(\n", + " (0): MaskedConv1d(\n", + " (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), groups=1024, bias=False)\n", + " )\n", + " (1): MaskedConv1d(\n", + " (conv): Conv1d(1024, 3072, kernel_size=(1,), stride=(1,), bias=False)\n", + " )\n", + " (2): BatchNorm1d(3072, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)\n", + " (3): SqueezeExcite(\n", + " (fc): Sequential(\n", + " (0): Linear(in_features=3072, out_features=384, bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Linear(in_features=384, out_features=3072, bias=False)\n", + " )\n", + " (gap): AdaptiveAvgPool1d(output_size=1)\n", + " )\n", + " )\n", + " (mout): Sequential(\n", + " (0): ReLU(inplace=True)\n", + " (1): Dropout(p=0.0, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (decoder): SpeakerDecoder(\n", + " (_pooling): AttentivePoolLayer(\n", + " 
(attention_layer): Sequential(\n", + " (0): TDNNModule(\n", + " (conv_layer): Conv1d(9216, 128, kernel_size=(1,), stride=(1,))\n", + " (activation): ReLU()\n", + " (bn): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): Tanh()\n", + " (2): Conv1d(128, 3072, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (emb_layers): ModuleList(\n", + " (0): Sequential(\n", + " (0): BatchNorm1d(6144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (1): Conv1d(6144, 192, kernel_size=(1,), stride=(1,))\n", + " )\n", + " )\n", + " (final): Linear(in_features=192, out_features=16681, bias=False)\n", + " )\n", + " (_macro_accuracy): MulticlassAccuracy()\n", + " (spec_augmentation): SpectrogramAugmentation(\n", + " (spec_augment): SpecAugment()\n", + " )\n", + " )\n", + " )\n", + " (_accuracy_test): MultiBinaryAccuracy()\n", + " (_accuracy_train): MultiBinaryAccuracy()\n", + " (_accuracy_valid): MultiBinaryAccuracy()\n", + ")" + ] + }, + "execution_count": 47, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from nemo.collections.asr.models import EncDecDiarLabelModel\n", + "from omegaconf import OmegaConf\n", + "\n", + "\n", + "msdd_cfg = OmegaConf.create({\n", + " \"model_path\": \"diar_msdd_telephonic\",\n", + " \"parameters\": {\n", + " \"use_speaker_model_from_ckpt\": True,\n", + " \"infer_batch_size\": 25,\n", + " \"sigmoid_threshold\": [0.7],\n", + " \"seq_eval_mode\": False,\n", + " \"split_infer\": True,\n", + " \"diar_window_length\": 50,\n", + " \"overlap_infer_spk_limit\": 5,\n", + " }\n", + "})\n", + "# msdd_model = EncDecDiarLabelModel.from_config_dict(msdd_cfg)\n", + "msdd_model = EncDecDiarLabelModel.from_pretrained(model_name=msdd_cfg.model_path)\n", + "msdd_model.eval()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Map MSDD + Clustering" + ] + }, + { + "cell_type": "code", + "execution_count": 123, + "metadata": {}, + "outputs": [], + "source": [ + "def get_overlap_stamps(contiguous_stamps: List[str], overlap_speaker_index: List[str]):\n", + " \"\"\"\n", + " Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are\n", + " created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`.\n", + "\n", + " Args:\n", + " cont_stamps (list):\n", + " Non-overlapping (single speaker per segment) diarization output in string format.\n", + " Each line contains the start and end time of segments and corresponding speaker labels.\n", + " ovl_spk_idx (list):\n", + " List containing segment index of the estimated overlapped speech. The start and end of segments are based on the\n", + " single-speaker (i.e., non-overlap-aware) RTTM generation.\n", + " Returns:\n", + " total_ovl_cont_list (list):\n", + " Rendered diarization output in string format. Each line contains the start and end time of segments and\n", + " corresponding speaker labels. 
This format is identical to `cont_stamps`.\n", + " \"\"\"\n", + " overlap_speaker_contiguous_list = [[] for _ in range(len(overlap_speaker_index))]\n", + " \n", + " for speaker_index in range(len(overlap_speaker_index)):\n", + " for index, segment in enumerate(contiguous_stamps):\n", + " start, end, _ = segment\n", + " if index in overlap_speaker_index[speaker_index]:\n", + " overlap_speaker_contiguous_list[speaker_index].append((start, end, speaker_index))\n", + "\n", + " total_overlap_contiguous_list = []\n", + "\n", + " for overlap_contiguous_list in overlap_speaker_contiguous_list:\n", + " if len(overlap_contiguous_list) > 0:\n", + " total_overlap_contiguous_list.extend(merge_stamps(overlap_contiguous_list))\n", + "\n", + " return total_overlap_contiguous_list\n", + "\n", + "\n", + "def generate_speaker_timestamps(\n", + " clustering_labels: List[Union[float, int]],\n", + " msdd_preds: torch.Tensor,\n", + " threshold: float = 0.7,\n", + " overlap_infer_speaker_limit: int = 5,\n", + " max_overlap_speakers: int = 2,\n", + ") -> Tuple[List[str], List[str]]:\n", + " '''\n", + " Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker\n", + " labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for\n", + " every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring.\n", + "\n", + " Args:\n", + " clus_labels (list):\n", + " List containing integer-valued speaker clustering results.\n", + " msdd_preds (list):\n", + " List containing tensors of the predicted sigmoid values.\n", + " Each tensor has shape of: (Session length, estimated number of speakers).\n", + " params:\n", + " Parameters for generating RTTM output and evaluation. Parameters include:\n", + " infer_overlap (bool): If False, overlap-speech will not be detected.\n", + " use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output\n", + " is used for constructing output RTTM files.\n", + " overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed.\n", + " use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated\n", + " number of speakers.\n", + " max_overlap_spks (int): Maximum number of overlap speakers detected. 
Default is 2.\n", + " threshold (float): Sigmoid threshold for MSDD output.\n", + "\n", + " Returns:\n", + " maj_labels (list):\n", + " List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels.\n", + " Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...]\n", + " ovl_labels (list):\n", + " List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels.\n", + " Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`.\n", + " Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...]\n", + " '''\n", + " estimated_num_of_spks = msdd_preds.shape[-1]\n", + " overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)]\n", + " infer_overlap = estimated_num_of_spks < int(overlap_infer_speaker_limit)\n", + "\n", + " main_speaker_lines = []\n", + " _threshold = threshold - (estimated_num_of_spks - 2) * (threshold - 1) / (\n", + " overlap_infer_speaker_limit - 2\n", + " )\n", + "\n", + " for segment_index, cluster_label in enumerate(clustering_labels):\n", + " speaker_for_segment = (msdd_preds[0, segment_index] > _threshold).int().tolist()\n", + " softmax_predictions = msdd_preds[0, segment_index]\n", + "\n", + " main_speaker_index = torch.argmax(msdd_preds[0, segment_index]).item()\n", + "\n", + " if sum(speaker_for_segment) > 1 and infer_overlap:\n", + " index_array = torch.argsort(softmax_predictions, descending=True)\n", + "\n", + " for overlap_speaker_index in index_array[: max_overlap_speakers].tolist():\n", + " if overlap_speaker_index != int(main_speaker_index):\n", + " overlap_speaker_list[overlap_speaker_index].append(segment_index)\n", + "\n", + " main_speaker_lines.append((cluster_label[0], cluster_label[1], main_speaker_index))\n", + "\n", + " contiguous_stamps = get_contiguous_stamps(main_speaker_lines)\n", + " main_labels = merge_stamps(contiguous_stamps)\n", + "\n", + " overlap_labels = get_overlap_stamps(contiguous_stamps, overlap_speaker_list)\n", + "\n", + " return main_labels, overlap_labels\n", + "\n", + "\n", + "def make_rttm_with_overlap(\n", + " clustering_labels: List[Union[float, int]],\n", + " msdd_preds: torch.Tensor,\n", + "):\n", + " \"\"\"\n", + " \"\"\"\n", + " main_labels, overlap_labels = generate_speaker_timestamps(clustering_labels, msdd_preds)\n", + "\n", + " # _hypothesis_labels = main_labels + overlap_labels\n", + " _hypothesis_labels = main_labels\n", + " hypothesis_labels = sorted(_hypothesis_labels, key=lambda x: x[0])\n", + "\n", + " return hypothesis_labels" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Real Diarization process" + ] + }, + { + "cell_type": "code", + "execution_count": 127, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 1.5 0.75\n", + "1 1.25 0.625\n", + "2 1.0 0.5\n", + "3 0.75 0.375\n", + "4 0.5 0.25\n" + ] + } + ], + "source": [ + "max_num_speakers = 8\n", + "window_lengths, shift_lengths, multiscale_weights = (\n", + " [1.5, 1.25, 1.0, 0.75, 0.5],\n", + " [0.75, 0.625, 0.5, 0.375, 0.25],\n", + " [1, 1, 1, 1, 1],\n", + ")\n", + "scale_dict = {k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths))}\n", + "\n", + "# VAD\n", + "waveform, _ = read_audio(\"./mono_file.wav\")\n", + "vad_service = VadService()\n", + "\n", + "vad_outputs, _ = vad_service(waveform, False)\n", + "\n", + "# 
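One detail worth spelling out from generate_speaker_timestamps above: the sigmoid cut-off is not fixed at the nominal 0.7 but rises towards 1.0 as the estimated speaker count grows, so overlap detection becomes stricter on crowded audio. A quick numeric check of that adjustment with the same default parameters (threshold=0.7, overlap_infer_speaker_limit=5):

threshold, overlap_infer_speaker_limit = 0.7, 5
for estimated_num_of_spks in (2, 3, 4):
    adjusted = threshold - (estimated_num_of_spks - 2) * (threshold - 1) / (overlap_infer_speaker_limit - 2)
    print(estimated_num_of_spks, round(adjusted, 3))
# 2 0.7
# 3 0.8
# 4 0.9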
Segmentation\n", + "all_embeddings, all_timestamps, all_segment_indexes = [], [], []\n", + "\n", + "scales = scale_dict.items()\n", + "for _, (window, shift) in scales:\n", + " scale_segments = _run_segmentation(vad_outputs, window, shift)\n", + "\n", + " _embeddings, _timestamps = _extract_embeddings(waveform, scale_segments)\n", + "\n", + " if len(_embeddings) != len(_timestamps):\n", + " raise ValueError(\"Mismatch of counts between embedding vectors and timestamps\")\n", + "\n", + " all_embeddings.append(_embeddings)\n", + " all_segment_indexes.append(_embeddings.shape[0])\n", + " all_timestamps.append(torch.tensor(_timestamps))\n", + "\n", + "multiscale_embeddings_and_timestamps = {\n", + " \"embeddings\": torch.cat(all_embeddings, dim=0),\n", + " \"timestamps\": torch.cat(all_timestamps, dim=0),\n", + " \"multiscale_segment_counts\": torch.tensor(all_segment_indexes),\n", + " \"multiscale_weights\": torch.tensor([1, 1, 1, 1, 1]).unsqueeze(0).float(),\n", + "}\n", + "\n", + "# Clustering\n", + "clustering_params = dict(\n", + " oracle_num_speakers=False,\n", + " max_num_speakers=max_num_speakers,\n", + " enhanced_count_thres=80,\n", + " max_rp_threshold=0.25,\n", + " sparse_search_volume=30,\n", + " maj_vote_spk_count=False,\n", + ")\n", + "clustering_labels = perform_clustering(\n", + " embs_and_timestamps=multiscale_embeddings_and_timestamps,\n", + " clustering_params=clustering_params,\n", + ")\n", + "\n", + "# Mapping between embeddings and timestamps on different scales\n", + "split_index = multiscale_embeddings_and_timestamps[\"multiscale_segment_counts\"].tolist()\n", + "embeddings_in_scales = list(torch.split(\n", + " multiscale_embeddings_and_timestamps[\"embeddings\"], split_index, dim=0\n", + "))\n", + "timestamps_in_scales = list(torch.split(\n", + " multiscale_embeddings_and_timestamps[\"timestamps\"], split_index, dim=0\n", + "))\n", + "session_scale_mapping_list = get_argmin_mat(timestamps_in_scales)\n", + "\n", + "scale_mapping_argmat, emb_scale_seq_dict = {}, {}\n", + "for scale_idx in range(len(session_scale_mapping_list)):\n", + " mapping_argmat = session_scale_mapping_list[scale_idx]\n", + " scale_mapping_argmat[scale_idx] = mapping_argmat\n", + "\n", + " emb_scale_seq_dict[scale_idx] = embeddings_in_scales[scale_idx]\n", + "\n", + "emb_sess_avg_dict = get_cluster_avg_embs(\n", + " emb_scale_seq_dict, clustering_labels, session_scale_mapping_list, len(scale_dict), max_num_speakers\n", + ")\n", + "\n", + "# MSDD algorithm\n", + "preds_list, targets_list, signal_lengths_list = [], [], []\n", + "dataset = AudioMSDDDataset(\n", + " emb_sess_avg_dict=emb_sess_avg_dict,\n", + " emb_scale_seq_dict=emb_scale_seq_dict,\n", + " sess_scale_mapping_list=session_scale_mapping_list,\n", + " clustering_labels=clustering_labels,\n", + " scale_n=len(scale_dict),\n", + ")\n", + "\n", + "dataloader = torch.utils.data.DataLoader(\n", + " dataset=dataset,\n", + " batch_size=1,\n", + " collate_fn=msdd_infer_collate_fn,\n", + " drop_last=False,\n", + " shuffle=False,\n", + " num_workers=0,\n", + " pin_memory=False,\n", + ")\n", + "\n", + "for batch in dataloader:\n", + " signals, signal_lengths, _, emb_vectors = batch\n", + "\n", + " # Convert data to float16\n", + " signals = signals.half().to(msdd_model.device)\n", + " signal_lengths = signal_lengths.half().to(msdd_model.device)\n", + " emb_vectors = emb_vectors.half().to(msdd_model.device)\n", + "\n", + " with autocast():\n", + " _preds, scale_weights = msdd_model.forward_infer(\n", + " input_signal=signals,\n", + " 
input_signal_length=signal_lengths,\n", + " emb_vectors=emb_vectors,\n", + " targets=None,\n", + " )\n", + " _preds = _preds.cpu().detach()\n", + " scale_weights = scale_weights.cpu().detach()\n", + "\n", + " max_pred_length = max(_preds.shape[1], 0)\n", + " preds = torch.zeros(_preds.shape[0], max_pred_length, _preds.shape[2])\n", + " targets = torch.zeros(_preds.shape[0], max_pred_length, _preds.shape[2])\n", + "\n", + " preds[:, : _preds.shape[1], :] = _preds\n", + "\n", + "all_hypothesis = make_rttm_with_overlap(clustering_labels, preds)\n", + "\n", + "contiguous_cluster = get_contiguous_stamps(clustering_labels)\n", + "last_cluster = merge_stamps(contiguous_cluster)\n", + "\n", + "# print(len(all_hypothesis), len(last_cluster))\n", + "# for h, c in zip(all_hypothesis, last_cluster):\n", + "# start, end, speaker = h\n", + "# print(f\"{h} | {c}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_config.py b/tests/test_config.py index 6fdac51..5566fdb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -29,15 +29,13 @@ def default_settings() -> OrderedDict: description="💬 ASR FastAPI server using faster-whisper and NVIDIA NeMo.", api_prefix="/api/v1", debug=True, - batch_size=1, - max_wait=0.1, whisper_model="large-v2", compute_type="float16", extra_languages=["he"], extra_languages_model_paths={"he": "path/to/model"}, - nemo_domain_type="general", - nemo_storage_path="nemo_storage", - nemo_output_path="nemo_outputs", + window_lengths=[1.5, 1.25, 1.0, 0.75, 0.5], + shift_lengths=[0.75, 0.625, 0.5, 0.375, 0.25], + multiscale_weights=[1.0, 1.0, 1.0, 1.0, 1.0], asr_type="async", audio_file_endpoint=True, audio_url_endpoint=True, @@ -66,17 +64,14 @@ def test_config() -> None: assert settings.api_prefix == "/api/v1" assert settings.debug is True - assert settings.batch_size == 1 - assert settings.max_wait == 0.1 - assert settings.whisper_model == "large-v2" assert settings.compute_type == "float16" - # assert settings.extra_languages == ["he"] - # assert settings.extra_languages_model_paths == {"he": ""} + assert settings.extra_languages == [""] + assert settings.extra_languages_model_paths == {"": ""} - assert settings.nemo_domain_type == "telephonic" - assert settings.nemo_storage_path == "nemo_storage" - assert settings.nemo_output_path == "nemo_outputs" + assert settings.window_lengths == [1.5, 1.25, 1.0, 0.75, 0.5] + assert settings.shift_lengths == [0.75, 0.625, 0.5, 0.375, 0.25] + assert settings.multiscale_weights == [1.0, 1.0, 1.0, 1.0, 1.0] assert settings.asr_type == "async" @@ -120,19 +115,6 @@ def test_general_parameters_validator(default_settings: dict) -> None: Settings(**wrong_api_prefix) -def test_batch_request_parameters_validator(default_settings: dict) -> None: - """Test batch request parameters validator.""" - wrong_batch_size = default_settings.copy() - wrong_batch_size["batch_size"] = 0 - with pytest.raises(ValueError): - Settings(**wrong_batch_size) - - wrong_max_wait = default_settings.copy() - wrong_max_wait["max_wait"] = -1 - with 
pytest.raises(ValueError): - Settings(**wrong_max_wait) - - def test_whisper_model_validator(default_settings: dict) -> None: """Test whisper model validator.""" wrong_whisper_model = default_settings.copy() @@ -153,13 +135,6 @@ def test_compute_type_validator(default_settings: dict) -> None: Settings(**default_settings) -def test_nemo_domain_type_validator(default_settings: dict) -> None: - """Test nemo domain type validator.""" - default_settings["nemo_domain_type"] = "invalid_domain_type" - with pytest.raises(ValueError): - Settings(**default_settings) - - def test_asr_type_validator(default_settings: dict) -> None: """Test asr type validator.""" default_settings["asr_type"] = "invalid_asr_type" diff --git a/tests/utils/test_load_nemo_config.py b/tests/utils/test_load_nemo_config.py deleted file mode 100644 index 7cfc690..0000000 --- a/tests/utils/test_load_nemo_config.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2023 The Wordcab Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests the load_nemo_config function.""" - -from pathlib import Path - -import pytest -from omegaconf import OmegaConf - -from wordcab_transcribe.utils import load_nemo_config - - -@pytest.mark.parametrize("domain_type", ["general", "meeting", "telephonic"]) -def test_load_nemo_config(domain_type: str): - """Test the load_nemo_config function.""" - cfg, _ = load_nemo_config( - domain_type, - "storage/path", - "output/path", - "cpu", - 0, - ) - - cfg_path = f"config/nemo/diar_infer_{domain_type}.yaml" - with open(cfg_path) as f: - data = OmegaConf.load(f) - - assert cfg != data - - assert cfg.num_workers == 0 - assert cfg.diarizer.manifest_filepath == str( - Path.cwd() / "storage/path/infer_manifest_0.json" - ) - assert cfg.diarizer.out_dir == str(Path.cwd() / "output/path") diff --git a/wordcab_transcribe/config.py b/wordcab_transcribe/config.py index 4f65907..c9aecf2 100644 --- a/wordcab_transcribe/config.py +++ b/wordcab_transcribe/config.py @@ -34,19 +34,16 @@ class Settings: description: str api_prefix: str debug: bool - # Batch configuration - batch_size: int - max_wait: float # Models configuration # Whisper whisper_model: str compute_type: str extra_languages: List[str] extra_languages_model_paths: Dict[str, str] - # NVIDIA NeMo - nemo_domain_type: str - nemo_storage_path: str - nemo_output_path: str + # Diarization + window_lengths: List[float] + shift_lengths: List[float] + multiscale_weights: List[float] # ASR type configuration asr_type: str # Endpoints configuration @@ -107,26 +104,6 @@ def api_prefix_must_not_be_none(cls, value: str): # noqa: B902, N805 return value - @field_validator("batch_size") - def batch_size_must_be_positive(cls, value: int): # noqa: B902, N805 - """Check that the batch_size is positive.""" - if value <= 0: - raise ValueError( - "`batch_size` must be positive, please verify the `.env` file." 
- ) - - return value - - @field_validator("max_wait") - def max_wait_must_be_positive(cls, value: float): # noqa: B902, N805 - """Check that the max_wait is positive.""" - if value <= 0: - raise ValueError( - "`max_wait` must be positive, please verify the `.env` file." - ) - - return value - @field_validator("whisper_model") def whisper_model_must_be_valid(cls, value: str): # noqa: B902, N805 """Check that the model name is valid. It can be a local path or a model name.""" @@ -152,17 +129,6 @@ def compute_type_must_be_valid(cls, value: str): # noqa: B902, N805 return value - @field_validator("nemo_domain_type") - def nemo_domain_type_must_be_valid(cls, value: str): # noqa: B902, N805 - """Check that the model precision is valid.""" - if value not in {"general", "telephonic", "meeting"}: - raise ValueError( - f"{value} is not a valid domain type. " - "Choose one of general, telephonic, meeting." - ) - - return value - @field_validator("asr_type") def asr_type_must_be_valid(cls, value: str): # noqa: B902, N805 """Check that the ASR type is valid.""" @@ -224,6 +190,16 @@ def __post_init__(self): "You can generate a new key with `openssl rand -hex 32`." ) + if ( + len(self.window_lengths) + != len(self.shift_lengths) + != len(self.multiscale_weights) + ): + raise ValueError( + f"Length of window_lengths, shift_lengths and multiscale_weights must be the same.\n" + f"Found: {len(self.window_lengths)}, {len(self.shift_lengths)}, {len(self.multiscale_weights)}" + ) + load_dotenv() @@ -234,6 +210,25 @@ def __post_init__(self): else: extra_languages = [] +# Diarization scales +_window_lengths = getenv("WINDOW_LENGTHS") +if _window_lengths is not None: + window_lengths = [float(x) for x in _window_lengths.split(",")] +else: + window_lengths = [1.5, 1.25, 1.0, 0.75, 0.5] + +_shift_lengths = getenv("SHIFT_LENGTHS") +if _shift_lengths is not None: + shift_lengths = [float(x) for x in _shift_lengths.split(",")] +else: + shift_lengths = [0.75, 0.625, 0.5, 0.375, 0.25] + +_multiscale_weights = getenv("MULTISCALE_WEIGHTS") +if _multiscale_weights is not None: + multiscale_weights = [float(x) for x in _multiscale_weights.split(",")] +else: + multiscale_weights = [1.0, 1.0, 1.0, 1.0, 1.0] + settings = Settings( # General configuration project_name=getenv("PROJECT_NAME", "Wordcab Transcribe"), @@ -243,9 +238,6 @@ def __post_init__(self): ), api_prefix=getenv("API_PREFIX", "/api/v1"), debug=getenv("DEBUG", True), - # Batch configuration - batch_size=getenv("BATCH_SIZE", 1), - max_wait=getenv("MAX_WAIT", 0.1), # Models configuration # Whisper whisper_model=getenv("WHISPER_MODEL", "large-v2"), @@ -253,9 +245,9 @@ def __post_init__(self): extra_languages=extra_languages, extra_languages_model_paths={lang: "" for lang in extra_languages}, # NeMo - nemo_domain_type=getenv("NEMO_DOMAIN_TYPE", "general"), - nemo_storage_path=getenv("NEMO_STORAGE_PATH", "nemo_storage"), - nemo_output_path=getenv("NEMO_OUTPUT_PATH", "nemo_outputs"), + window_lengths=window_lengths, + shift_lengths=shift_lengths, + multiscale_weights=multiscale_weights, # ASR type asr_type=getenv("ASR_TYPE", "async"), # Endpoints configuration diff --git a/wordcab_transcribe/services/asr_service.py b/wordcab_transcribe/services/asr_service.py index 190462b..b990bf6 100644 --- a/wordcab_transcribe/services/asr_service.py +++ b/wordcab_transcribe/services/asr_service.py @@ -88,11 +88,11 @@ def __init__(self) -> None: device_index=device_index, ), "diarization": DiarizeService( - domain_type=settings.nemo_domain_type, - 
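For context on the new diarization settings wired in above: the three comma-separated .env values are parsed into equal-length float lists and, inside DiarizeService, zipped into the scale dictionary that drives multiscale segmentation. A minimal sketch with the default values hard-coded purely for illustration:

window_lengths = [float(x) for x in "1.5,1.25,1.0,0.75,0.5".split(",")]
shift_lengths = [float(x) for x in "0.75,0.625,0.5,0.375,0.25".split(",")]
multiscale_weights = [float(x) for x in "1.0,1.0,1.0,1.0,1.0".split(",")]

# The three lists must stay the same length (cf. the __post_init__ check added above).
assert len(window_lengths) == len(shift_lengths) == len(multiscale_weights)

scale_dict = {k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths))}
print(scale_dict)
# {0: (1.5, 0.75), 1: (1.25, 0.625), 2: (1.0, 0.5), 3: (0.75, 0.375), 4: (0.5, 0.25)}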
storage_path=settings.nemo_storage_path, - output_path=settings.nemo_output_path, device=self.device, device_index=device_index, + window_lengths=settings.window_lengths, + shift_lengths=settings.shift_lengths, + multiscale_weights=settings.multiscale_weights, ), "alignment": AlignService(self.device), "post_processing": PostProcessingService(), @@ -288,7 +288,9 @@ def process_diarization(self, task: dict, gpu_index: int) -> None: None: The task is updated with the result. """ try: - result = self.services["diarization"](task["input"], model_index=gpu_index) + result = self.services["diarization"]( + task["input"], model_index=gpu_index, vad_service=self.services["vad"] + ) except Exception as e: result = Exception(f"Error in diarization: {e}\n{traceback.format_exc()}") diff --git a/wordcab_transcribe/services/diarize_service.py b/wordcab_transcribe/services/diarize_service.py index 324a7dc..f97e933 100644 --- a/wordcab_transcribe/services/diarize_service.py +++ b/wordcab_transcribe/services/diarize_service.py @@ -13,24 +13,362 @@ # limitations under the License. """Diarization Service for audio files.""" -from pathlib import Path -from typing import List, NamedTuple, Union +import math +from typing import Dict, List, NamedTuple, Tuple, Union -import librosa -import soundfile as sf import torch -from nemo.collections.asr.models.msdd_models import NeuralDiarizer +from nemo.collections.asr.models import EncDecSpeakerLabelModel +from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering +from torch.cuda.amp import autocast +from torch.utils.data import Dataset from wordcab_transcribe.logging import time_and_tell -from wordcab_transcribe.utils import load_nemo_config +from wordcab_transcribe.services.vad_service import VadService -class NemoModel(NamedTuple): - """NeMo Model.""" +class MultiscaleEmbeddingsAndTimestamps(NamedTuple): + """Multiscale embeddings and timestamps outputs of the SegmentationModule.""" - model: NeuralDiarizer - output_path: str - tmp_audio_path: str + embeddings: torch.Tensor + timestamps: torch.Tensor + multiscale_segment_counts: torch.Tensor + multiscale_weights: torch.Tensor + + +class AudioSegmentDataset(Dataset): + """Dataset for audio segments used by the SegmentationModule.""" + + def __init__( + self, waveform: torch.Tensor, segments: List[dict], sample_rate=16000 + ) -> None: + """ + Initialize the dataset for the SegmentationModule. + + Args: + waveform (torch.Tensor): Waveform of the audio file. + segments (List[dict]): List of segments with the following keys: "offset", "duration". + sample_rate (int): Sample rate of the audio file. Defaults to 16000. + """ + self.waveform = waveform + self.segments = segments + self.sample_rate = sample_rate + + def __len__(self) -> int: + """Get the length of the dataset.""" + return len(self.segments) + + def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get an item from the dataset. + + Args: + idx (int): Index of the item to get. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of the audio segment and its length. 
+ """ + segment_info = self.segments[idx] + offset_samples = int(segment_info["offset"] * self.sample_rate) + duration_samples = int(segment_info["duration"] * self.sample_rate) + + segment = self.waveform[offset_samples : offset_samples + duration_samples] + + return segment, torch.tensor(segment.shape[0]).long() + + +def segmentation_collate_fn( + batch: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Collate function used by the dataloader of the SegmentationModule. + + Args: + batch (List[Tuple[torch.Tensor, torch.Tensor]]): List of audio segments and their lengths. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of the audio segments and their lengths. + """ + _, audio_lengths = zip(*batch) + + if not audio_lengths[0]: + return None, None + + fixed_length = int(max(audio_lengths)) + + audio_signal, new_audio_lengths = [], [] + for sig, sig_len in batch: + sig_len = sig_len.item() + chunck_len = sig_len - fixed_length + + if chunck_len < 0: + repeat = fixed_length // sig_len + rem = fixed_length % sig_len + sub = sig[-rem:] if rem > 0 else torch.tensor([]) + rep_sig = torch.cat(repeat * [sig]) + sig = torch.cat((rep_sig, sub)) + new_audio_lengths.append(torch.tensor(fixed_length)) + + audio_signal.append(sig) + + audio_signal = torch.stack(audio_signal) + audio_lengths = torch.stack(new_audio_lengths) + + return audio_signal, audio_lengths + + +class SegmentationModule: + """Segmentation module for diariation.""" + + def __init__(self, device: str, multiscale_weights: List[float]) -> None: + """ + Initialize the segmentation module. + + Args: + device (str): Device to use for inference. Can be "cpu" or "cuda". + multiscale_weights (List[float]): List of weights for each scale. + """ + self.multiscale_weights = torch.tensor(multiscale_weights).unsqueeze(0).float() + + if len(multiscale_weights) > 3: + self.batch_size = 64 + elif len(multiscale_weights) > 1: + self.batch_size = 128 + else: + self.batch_size = 256 + + self.speaker_model = EncDecSpeakerLabelModel.from_pretrained( + model_name="titanet_large", map_location=None + ).to(device) + self.speaker_model.eval() + + def __call__( + self, + waveform: torch.Tensor, + vad_outputs: List[dict], + scale_dict: Dict[int, Tuple[float, float]], + ) -> MultiscaleEmbeddingsAndTimestamps: + """ + Run the segmentation module. + + Args: + waveform (torch.Tensor): Waveform of the audio file. + vad_outputs (List[dict]): List of segments with the following keys: "start", "end". + scale_dict (Dict[int, Tuple[float, float]]): Dictionary of scales in the format {scale_id: (window, shift)}. + + Returns: + MultiscaleEmbeddingsAndTimestamps: Embeddings and timestamps of the audio file. + + Raises: + ValueError: If there is a mismatch of counts between embedding vectors and timestamps. 
+ """ + embeddings, timestamps, segment_indexes = [], [], [] + + for _, (window, shift) in scale_dict.items(): + scale_segments = self.get_audio_segments_from_scale( + vad_outputs, window, shift + ) + + _embeddings, _timestamps = self.extract_embeddings(waveform, scale_segments) + + if len(_embeddings) != len(_timestamps): + raise ValueError( + "Mismatch of counts between embedding vectors and timestamps" + ) + + embeddings.append(_embeddings) + segment_indexes.append(_embeddings.shape[0]) + timestamps.append(torch.tensor(_timestamps)) + + return MultiscaleEmbeddingsAndTimestamps( + embeddings=torch.cat(embeddings, dim=0), + timestamps=torch.cat(timestamps, dim=0), + multiscale_segment_counts=torch.tensor(segment_indexes), + multiscale_weights=self.multiscale_weights, + ) + + def get_audio_segments_from_scale( + self, + vad_outputs: List[dict], + window: float, + shift: float, + min_subsegment_duration: float = 0.05, + ) -> List[dict]: + """ + Return a list of audio segments based on the VAD outputs and the scale window and shift length. + + Args: + vad_outputs (List[dict]): List of segments with the following keys: "start", "end". + window (float): Window length. Used to get subsegments. + shift (float): Shift length. Used to get subsegments. + min_subsegment_duration (float): Minimum duration of a subsegment in seconds. + + Returns: + List[dict]: List of audio segments with the following keys: "offset", "duration". + """ + scale_segment = [] + for segment in vad_outputs: + segment_start, segment_end = ( + segment["start"] / 16000, + segment["end"] / 16000, + ) + subsegments = self.get_subsegments( + segment_start, segment_end, window, shift + ) + + for subsegment in subsegments: + start, duration = subsegment + if duration > min_subsegment_duration: + scale_segment.append({"offset": start, "duration": duration}) + + return scale_segment + + def extract_embeddings( + self, waveform: torch.Tensor, scale_segments: List[dict] + ) -> Tuple[torch.Tensor, List[List[float]]]: + """ + This method extracts speaker embeddings from the audio file based on the scale segments. + + Args: + waveform (torch.Tensor): Waveform of the audio file. + scale_segments (List[dict]): List of segments with the following keys: "offset", "duration". + + Returns: + Tuple[torch.Tensor, List[List[float]]]: Tuple of embeddings and timestamps. 
+ """ + all_embs = torch.empty([0]) + + dataset = AudioSegmentDataset(waveform, scale_segments) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=False, + collate_fn=segmentation_collate_fn, + ) + + for batch in dataloader: + _batch = [x.to(self.speaker_model.device) for x in batch] + audio_signal, audio_signal_len = _batch + + with autocast(): + _, embeddings = self.speaker_model.forward( + input_signal=audio_signal, input_signal_length=audio_signal_len + ) + embeddings = embeddings.view(-1, embeddings.shape[-1]) + all_embs = torch.cat((all_embs, embeddings.cpu().detach()), dim=0) + + del _batch, audio_signal, audio_signal_len, embeddings + + embeddings, time_stamps = [], [] + for i, segment in enumerate(scale_segments): + if i == 0: + embeddings = all_embs[i].view(1, -1) + else: + embeddings = torch.cat((embeddings, all_embs[i].view(1, -1))) + + time_stamps.append([segment["offset"], segment["duration"]]) + + return embeddings, time_stamps + + @staticmethod + def get_subsegments( + segment_start: float, segment_end: float, window: float, shift: float + ) -> List[List[float]]: + """ + Return a list of subsegments based on the segment start and end time and the window and shift length. + + Args: + segment_start (float): Segment start time. + segment_end (float): Segment end time. + window (float): Window length. + shift (float): Shift length. + + Returns: + List[List[float]]: List of subsegments with start time and duration. + """ + start = segment_start + duration = segment_end - segment_start + base = math.ceil((duration - window) / shift) + + subsegments: List[List[float]] = [] + slices = 1 if base < 0 else base + 1 + for slice_id in range(slices): + end = start + window + + if end > segment_end: + end = segment_end + + subsegments.append([start, end - start]) + + start = segment_start + (slice_id + 1) * shift + + return subsegments + + +class ClusteringModule: + """Clustering module for diariation.""" + + def __init__(self, device: str, max_num_speakers: int = 8) -> None: + """Initialize the clustering module.""" + self.params = dict( + oracle_num_speakers=False, + max_num_speakers=max_num_speakers, + enhanced_count_thres=80, + max_rp_threshold=0.25, + sparse_search_volume=30, + maj_vote_spk_count=False, + ) + self.clustering_model = SpeakerClustering(parallelism=True, cuda=True) + self.clustering_model.device = device + + def __call__( + self, ms_emb_ts: MultiscaleEmbeddingsAndTimestamps + ) -> List[Tuple[float, float, int]]: + """ + Run the clustering module and return the speaker segments. + + Args: + ms_emb_ts (MultiscaleEmbeddingsAndTimestamps): Embeddings and timestamps of the audio file in multiscale. + The multiscale embeddings and timestamps are from the SegmentationModule. + + Returns: + List[Tuple[float, float, int]]: List of segments with the following keys: "start", "end", "speaker". 
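A worked example of the subsegment arithmetic in get_subsegments above, with invented numbers: a 4-second VAD segment cut at the largest scale (window 1.5 s, shift 0.75 s). The loop below mirrors the same maths in standalone form.

import math

segment_start, segment_end, window, shift = 0.0, 4.0, 1.5, 0.75

start = segment_start
duration = segment_end - segment_start
base = math.ceil((duration - window) / shift)   # ceil(2.5 / 0.75) = 4
slices = 1 if base < 0 else base + 1            # 5 subsegments

subsegments = []
for slice_id in range(slices):
    end = min(start + window, segment_end)      # clamp the last window to the segment end
    subsegments.append([start, end - start])    # stored as [start, duration]
    start = segment_start + (slice_id + 1) * shift

print(subsegments)
# [[0.0, 1.5], [0.75, 1.5], [1.5, 1.5], [2.25, 1.5], [3.0, 1.0]]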
+ """ + base_scale_idx = ms_emb_ts.multiscale_segment_counts.shape[0] - 1 + cluster_labels = self.clustering_model.forward_infer( + embeddings_in_scales=ms_emb_ts.embeddings, + timestamps_in_scales=ms_emb_ts.timestamps, + multiscale_segment_counts=ms_emb_ts.multiscale_segment_counts, + multiscale_weights=ms_emb_ts.multiscale_weights, + oracle_num_speakers=-1, + max_num_speakers=self.params["max_num_speakers"], + max_rp_threshold=self.params["max_rp_threshold"], + sparse_search_volume=self.params["sparse_search_volume"], + ) + + del ms_emb_ts + torch.cuda.empty_cache() + + timestamps = self.clustering_model.timestamps_in_scales[base_scale_idx] + cluster_labels = cluster_labels.cpu().numpy() + + if len(cluster_labels) != timestamps.shape[0]: + raise ValueError( + "Mismatch of length between cluster_labels and timestamps." + ) + + clustering_labels = [] + for idx, label in enumerate(cluster_labels): + start, end = timestamps[idx] + clustering_labels.append((float(start), float(start + end), int(label))) + + return clustering_labels + + +class DiarizationModels(NamedTuple): + """Diarization Models.""" + + segmentation: SegmentationModule + clustering: ClusteringModule device: str @@ -39,48 +377,56 @@ class DiarizeService: def __init__( self, - domain_type: str, - storage_path: str, - output_path: str, device: str, device_index: List[int], + window_lengths: List[float], + shift_lengths: List[float], + multiscale_weights: List[int], + max_num_speakers: int = 8, ) -> None: """Initialize the Diarize Service. - This service uses the NeuralDiarizer from NeMo to diarize audio files. + This service uses the NVIDIA NeMo diarization models. Args: - domain_type (str): Domain type to use for diarization. Can be "general", "telephonic" or "meeting". - storage_path (str): Path where the diarization pipeline will save temporary files. - output_path (str): Path where the diarization pipeline will save the final output files. device (str): Device to use for inference. Can be "cpu" or "cuda". device_index (Union[int, List[int]]): Index of the device to use for inference. + window_lengths (List[float]): List of window lengths. + shift_lengths (List[float]): List of shift lengths. + multiscale_weights (List[int]): List of weights for each scale. + max_num_speakers (int): Maximum number of speakers. Defaults to 8. 
""" self.device = device self.models = {} - for idx in device_index: - _output_path = Path(output_path) / f"output_{idx}" + # Multi-scale segmentation diarization + self.max_num_speakers = max_num_speakers + self.window_lengths = window_lengths + self.shift_lengths = shift_lengths + self.multiscale_weights = multiscale_weights + self.scale_dict = { + k: (w, s) for k, (w, s) in enumerate(zip(window_lengths, shift_lengths)) + } + + for idx in device_index: _device = f"cuda:{idx}" if self.device == "cuda" else "cpu" - cfg, tmp_audio_path = load_nemo_config( - domain_type=domain_type, - storage_path=storage_path, - output_path=_output_path, - device=_device, - index=idx, - ) - model = NeuralDiarizer(cfg=cfg).to(_device) - self.models[idx] = NemoModel( - model=model, - output_path=_output_path, - tmp_audio_path=tmp_audio_path, + + segmentation_module = SegmentationModule(_device, self.multiscale_weights) + clustering_module = ClusteringModule(_device, self.max_num_speakers) + + self.models[idx] = DiarizationModels( + segmentation=segmentation_module, + clustering=clustering_module, device=_device, ) @time_and_tell def __call__( - self, filepath: Union[str, torch.Tensor], model_index: int + self, + filepath: Union[str, torch.Tensor], + model_index: int, + vad_service: VadService, ) -> List[dict]: """ Run inference with the diarization model. @@ -88,45 +434,82 @@ def __call__( Args: filepath (Union[str, torch.Tensor]): Path to the audio file or waveform. model_index (int): Index of the model to use for inference. + vad_service (VadService): VAD service instance to use for Voice Activity Detection. Returns: List[dict]: List of segments with the following keys: "start", "end", "speaker". """ - if isinstance(filepath, str): - waveform, sample_rate = librosa.load(filepath, sr=None) - else: - waveform = filepath - sample_rate = 16000 + vad_outputs, _ = vad_service(filepath, False) - sf.write( - self.models[model_index].tmp_audio_path, waveform, sample_rate, "PCM_16" + ms_emb_ts: MultiscaleEmbeddingsAndTimestamps = self.models[ + model_index + ].segmentation( + waveform=filepath, + vad_outputs=vad_outputs, + scale_dict=self.scale_dict, ) - self.models[model_index].model.diarize() + clustering_outputs = self.models[model_index].clustering(ms_emb_ts) - outputs = self._format_timestamps(self.models[model_index].output_path) + _outputs = self.get_contiguous_stamps(clustering_outputs) + outputs = self.merge_stamps(_outputs) return outputs @staticmethod - def _format_timestamps(output_path: str) -> List[dict]: + def get_contiguous_stamps( + stamps: List[Tuple[float, float, int]] + ) -> List[Tuple[float, float, int]]: """ - Format timestamps from the diarization pipeline. + Return contiguous timestamps. Args: - output_path (str): Path where the diarization pipeline saved the final output files. + stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker. Returns: - List[dict]: List of segments with the following keys: "start", "end", "speaker". + List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker. 
""" - speaker_timestamps = [] + contiguous_stamps = [] + for i in range(len(stamps) - 1): + start, end, speaker = stamps[i] + next_start, next_end, next_speaker = stamps[i + 1] + + if end > next_start: + avg = (next_start + end) / 2.0 + stamps[i + 1] = (avg, next_end, next_speaker) + contiguous_stamps.append((start, avg, speaker)) + else: + contiguous_stamps.append((start, end, speaker)) + + start, end, speaker = stamps[-1] + contiguous_stamps.append((start, end, speaker)) + + return contiguous_stamps + + @staticmethod + def merge_stamps( + stamps: List[Tuple[float, float, int]] + ) -> List[Tuple[float, float, int]]: + """ + Merge timestamps of the same speaker. + + Args: + stamps (List[Tuple[float, float, int]]): List of segments containing the start time, end time and speaker. + + Returns: + List[Tuple[float, float, int]]: List of segments containing the start time, end time and speaker. + """ + overlap_stamps = [] + for i in range(len(stamps) - 1): + start, end, speaker = stamps[i] + next_start, next_end, next_speaker = stamps[i + 1] + + if end == next_start and speaker == next_speaker: + stamps[i + 1] = (start, next_end, next_speaker) + else: + overlap_stamps.append((start, end, speaker)) - with open(f"{output_path}/pred_rttms/mono_file.rttm") as f: - lines = f.readlines() - for line in lines: - line_list = line.split(" ") - s = int(float(line_list[5]) * 1000) - e = s + int(float(line_list[8]) * 1000) - speaker_timestamps.append([s, e, int(line_list[11].split("_")[-1])]) + start, end, speaker = stamps[-1] + overlap_stamps.append((start, end, speaker)) - return speaker_timestamps + return overlap_stamps diff --git a/wordcab_transcribe/services/post_processing_service.py b/wordcab_transcribe/services/post_processing_service.py index 1121faa..41f4b19 100644 --- a/wordcab_transcribe/services/post_processing_service.py +++ b/wordcab_transcribe/services/post_processing_service.py @@ -15,13 +15,7 @@ from typing import Any, Dict, List -from wordcab_transcribe.utils import ( - _convert_ms_to_s, - _convert_s_to_ms, - convert_timestamp, - format_punct, - is_empty_string, -) +from wordcab_transcribe.utils import convert_timestamp, format_punct, is_empty_string class PostProcessingService: @@ -106,12 +100,12 @@ def segments_speaker_mapping( while segment_index < len(transcript_segments): segment = transcript_segments[segment_index] segment_start, segment_end, segment_text = ( - _convert_s_to_ms(segment["start"]), - _convert_s_to_ms(segment["end"]), + segment["start"], + segment["end"], segment["text"], ) - while segment_start > float(end) or abs(segment_start - float(end)) < 300: + while segment_start > float(end) or abs(segment_start - float(end)) < 0.3: turn_idx += 1 turn_idx = min(turn_idx, len(speaker_timestamps) - 1) _, end, speaker = speaker_timestamps[turn_idx] @@ -126,8 +120,8 @@ def segments_speaker_mapping( ( i for i, word in enumerate(words) - if _convert_s_to_ms(word["start"]) > float(end) - or abs(_convert_s_to_ms(word["start"]) - float(end)) < 300 + if word["start"] > float(end) + or abs(word["start"] - float(end)) < 0.3 ), None, ) @@ -158,7 +152,7 @@ def segments_speaker_mapping( segment_index + 1, dict( start=words[word_index]["start"], - end=_convert_ms_to_s(segment_end), + end=segment_end, text=" ".join(_splitted_segment[word_index:]), words=words[word_index:], ), @@ -166,8 +160,8 @@ def segments_speaker_mapping( else: segment_speaker_mapping.append( dict( - start=_convert_ms_to_s(segment_start), - end=_convert_ms_to_s(segment_end), + start=segment_start, + end=segment_end, 
diff --git a/wordcab_transcribe/services/post_processing_service.py b/wordcab_transcribe/services/post_processing_service.py
index 1121faa..41f4b19 100644
--- a/wordcab_transcribe/services/post_processing_service.py
+++ b/wordcab_transcribe/services/post_processing_service.py
@@ -15,13 +15,7 @@
 from typing import Any, Dict, List
 
-from wordcab_transcribe.utils import (
-    _convert_ms_to_s,
-    _convert_s_to_ms,
-    convert_timestamp,
-    format_punct,
-    is_empty_string,
-)
+from wordcab_transcribe.utils import convert_timestamp, format_punct, is_empty_string
 
 
 class PostProcessingService:
@@ -106,12 +100,12 @@ def segments_speaker_mapping(
         while segment_index < len(transcript_segments):
             segment = transcript_segments[segment_index]
             segment_start, segment_end, segment_text = (
-                _convert_s_to_ms(segment["start"]),
-                _convert_s_to_ms(segment["end"]),
+                segment["start"],
+                segment["end"],
                 segment["text"],
             )
 
-            while segment_start > float(end) or abs(segment_start - float(end)) < 300:
+            while segment_start > float(end) or abs(segment_start - float(end)) < 0.3:
                 turn_idx += 1
                 turn_idx = min(turn_idx, len(speaker_timestamps) - 1)
                 _, end, speaker = speaker_timestamps[turn_idx]
@@ -126,8 +120,8 @@ def segments_speaker_mapping(
                 (
                     i
                     for i, word in enumerate(words)
-                    if _convert_s_to_ms(word["start"]) > float(end)
-                    or abs(_convert_s_to_ms(word["start"]) - float(end)) < 300
+                    if word["start"] > float(end)
+                    or abs(word["start"] - float(end)) < 0.3
                 ),
                 None,
             )
@@ -158,7 +152,7 @@ def segments_speaker_mapping(
                     segment_index + 1,
                     dict(
                         start=words[word_index]["start"],
-                        end=_convert_ms_to_s(segment_end),
+                        end=segment_end,
                         text=" ".join(_splitted_segment[word_index:]),
                         words=words[word_index:],
                     ),
@@ -166,8 +160,8 @@
             else:
                 segment_speaker_mapping.append(
                     dict(
-                        start=_convert_ms_to_s(segment_start),
-                        end=_convert_ms_to_s(segment_end),
+                        start=segment_start,
+                        end=segment_end,
                         text=segment_text,
                         speaker=speaker,
                         words=words,
@@ -176,8 +170,8 @@
             else:
                 segment_speaker_mapping.append(
                     dict(
-                        start=_convert_ms_to_s(segment_start),
-                        end=_convert_ms_to_s(segment_end),
+                        start=segment_start,
+                        end=segment_end,
                         text=segment_text,
                         speaker=speaker,
                         words=segment["words"],
diff --git a/wordcab_transcribe/services/transcribe_service.py b/wordcab_transcribe/services/transcribe_service.py
index 23c4109..e1d9518 100644
--- a/wordcab_transcribe/services/transcribe_service.py
+++ b/wordcab_transcribe/services/transcribe_service.py
@@ -399,10 +399,10 @@ def __call__(
 
         # if not use_batch and not isinstance(audio, tuple):
         if (
-                vocab is not None
-                and isinstance(vocab, list)
-                and len(vocab) > 0
-                and vocab[0].strip()
+            vocab is not None
+            and isinstance(vocab, list)
+            and len(vocab) > 0
+            and vocab[0].strip()
         ):
             words = ", ".join(vocab)
             prompt = f"Vocab: {words.strip()}"
diff --git a/wordcab_transcribe/utils.py b/wordcab_transcribe/utils.py
index 57f6ab8..693cc58 100644
--- a/wordcab_transcribe/utils.py
+++ b/wordcab_transcribe/utils.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 """Utils module of the Wordcab Transcribe."""
 import asyncio
-import json
 import math
 import re
 import subprocess  # noqa: S404
@@ -30,7 +29,6 @@
 from fastapi import UploadFile
 from loguru import logger
 from num2words import num2words
-from omegaconf import DictConfig, ListConfig, OmegaConf
 from yt_dlp import YoutubeDL
 
 
@@ -564,71 +562,6 @@ def interpolate_nans(x: pd.Series, method="nearest") -> pd.Series:
     return x.ffill().bfill()
 
 
-def load_nemo_config(
-    domain_type: str,
-    storage_path: str,
-    output_path: Path,
-    device: str,
-    index: int,
-) -> Union[DictConfig, ListConfig]:
-    """
-    Load NeMo config file based on a domain type.
-
-    Args:
-        domain_type (str): The domain type. Can be "general", "meeting" or "telephonic".
-        storage_path (str): The path to the NeMo storage directory.
-        output_path (Path): The path to the NeMo output directory.
-        device (str): The device to use for inference.
-        index (int): The index of the model to use for inference. Used to create separate folders and files.
-
-    Returns:
-        Tuple[DictConfig, str]: The NeMo config loaded as a DictConfig and the audio filepath.
-    """
-    cfg_path = (
-        Path(__file__).parent.parent
-        / "config"
-        / "nemo"
-        / f"diar_infer_{domain_type}.yaml"
-    )
-    with open(cfg_path) as f:
-        cfg = OmegaConf.load(f)
-
-    _storage_path = Path(__file__).parent.parent / storage_path
-    if not _storage_path.exists():
-        _storage_path.mkdir(parents=True, exist_ok=True)
-
-    temp_folder = Path.cwd() / f"temp_outputs_{index}"
-    if not temp_folder.exists():
-        temp_folder.mkdir(parents=True, exist_ok=True)
-
-    audio_filepath = str(temp_folder / "mono_file.wav")
-    meta = {
-        "audio_filepath": audio_filepath,
-        "offset": 0,
-        "duration": None,
-        "label": "infer",
-        "text": "-",
-        "rttm_filepath": None,
-        "uem_filepath": None,
-    }
-
-    manifest_path = _storage_path / f"infer_manifest_{index}.json"
-    with open(manifest_path, "w") as fp:
-        json.dump(meta, fp)
-        fp.write("\n")
-
-    _output_path = Path(__file__).parent.parent / output_path
-    if not _output_path.exists():
-        _output_path.mkdir(parents=True, exist_ok=True)
-
-    cfg.num_workers = 0
-    cfg.device = device
-    cfg.diarizer.manifest_filepath = str(manifest_path)
-    cfg.diarizer.out_dir = str(_output_path)
-
-    return cfg, audio_filepath
-
-
 def read_audio(filepath: str, sample_rate: int = 16000) -> Tuple[torch.Tensor, float]:
     """
     Read an audio file and return the audio tensor.