Skip to content

Commit

Permalink
Enabling longform diarization, skeleton for punct-based alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleks committed Mar 31, 2024
1 parent 60e8ca7 commit 4422e79
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 3 deletions.
7 changes: 7 additions & 0 deletions .env
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ SHIFT_LENGTHS="1.0,0.75,0.625,0.5,0.25"
# "1.0, 1.0, 1.0, 1.0, 1.0".
MULTISCALE_WEIGHTS="1.0,1.0,1.0,1.0,1.0"
#
# -------------------------------------------------- POST-PROCESSING ------------------------------------------------- #
#
# This parameter controls the punctuation-based alignment. If set to True, the predicted punctuation
# is used to adjust the speaker diarization. When this variable is unset, the application defaults to True.
# Note that enabling it comes with a performance tradeoff.
ENABLE_PUNCTUATION_BASED_ALIGNMENT=False
#
# ---------------------------------------------- ASR TYPE CONFIGURATION ---------------------------------------------- #
#
# The asr_type parameter is used to control the type of ASR used. The available options are: `async` or `live`.
Expand Down
6 changes: 6 additions & 0 deletions src/wordcab_transcribe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ class Settings:
window_lengths: List[float]
shift_lengths: List[float]
multiscale_weights: List[float]
# Post-processing
enable_punctuation_based_alignment: bool
# ASR type configuration
asr_type: Literal["async", "live", "only_transcription", "only_diarization"]
# Endpoint configuration
Expand Down Expand Up @@ -325,6 +327,10 @@ def __post_init__(self):
window_lengths=window_lengths,
shift_lengths=shift_lengths,
multiscale_weights=multiscale_weights,
# Post-processing
enable_punctuation_based_alignment=getenv(
"ENABLE_PUNCTUATION_BASED_ALIGNMENT", True
),
# ASR type
asr_type=getenv("ASR_TYPE", "async"),
# Endpoints configuration
Expand Down
11 changes: 11 additions & 0 deletions src/wordcab_transcribe/services/asr_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,17 @@ def process_post_processing(self, task: ASRTask) -> None:
else:
utterances = formatted_segments

if settings.enable_punctuation_based_alignment:
utterances, process_time = time_and_tell(
self.local_services.post_processing.punctuation_based_alignment(
utterances=utterances,
speaker_timestamps=task.diarization.result,
),
func_name="punctuation_based_alignment",
debug_mode=self.debug_mode,
)
total_post_process_time += process_time

final_utterances, process_time = time_and_tell(
self.local_services.post_processing.final_processing_before_returning(
utterances=utterances,
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.08, "smoothing": false, "overlap": 0.5, "onset": 0.5, "offset": 0.3, "pad_onset": 0.2, "pad_offset": 0.2, "min_duration_on": 0.5, "min_duration_off": 0.5, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.9, 1.2, 0.5], "shift_length_in_sec": [0.95, 0.6, 0.25], "multiscale_weights": [1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 10, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": null, "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.08, "smoothing": false, "overlap": 0.5, "onset": 0.5, "offset": 0.3, "pad_onset": 0.2, "pad_offset": 0.2, "min_duration_on": 0.5, "min_duration_off": 0.5, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.9, 1.2, 0.5], "shift_length_in_sec": [0.95, 0.6, 0.25], "multiscale_weights": [1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 10, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": null, "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.01, "smoothing": false, "overlap": 0.5, "onset": 0.9, "offset": 0.5, "pad_onset": 0, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.6, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5], "shift_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5, 0.25], "multiscale_weights": [1, 1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.63, "shift_length_in_sec": 0.01, "smoothing": false, "overlap": 0.5, "onset": 0.9, "offset": 0.5, "pad_onset": 0, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.6, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [3.0, 2.5, 2.0, 1.5, 1.0, 0.5], "shift_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5, 0.25], "multiscale_weights": [1, 1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": null, "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": "???", "out_dir": "???", "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.15, "shift_length_in_sec": 0.01, "smoothing": "median", "overlap": 0.5, "onset": 0.1, "offset": 0.1, "pad_onset": 0.1, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.2, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], "multiscale_weights": [1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": "diar_msdd_telephonic", "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
{"name": "ClusterDiarizer", "num_workers": 1, "sample_rate": 16000, "batch_size": 64, "device": null, "verbose": true, "diarizer": {"manifest_filepath": null, "out_dir": null, "oracle_vad": false, "collar": 0.25, "ignore_overlap": true, "vad": {"model_path": "vad_multilingual_marblenet", "external_vad_manifest": null, "parameters": {"window_length_in_sec": 0.15, "shift_length_in_sec": 0.01, "smoothing": "median", "overlap": 0.5, "onset": 0.1, "offset": 0.1, "pad_onset": 0.1, "pad_offset": 0, "min_duration_on": 0, "min_duration_off": 0.2, "filter_speech_first": true}}, "speaker_embeddings": {"model_path": "titanet_large", "parameters": {"window_length_in_sec": [1.5, 1.25, 1.0, 0.75, 0.5], "shift_length_in_sec": [0.75, 0.625, 0.5, 0.375, 0.25], "multiscale_weights": [1, 1, 1, 1, 1], "save_embeddings": true}}, "clustering": {"parameters": {"oracle_num_speakers": false, "max_num_speakers": 8, "enhanced_count_thres": 80, "max_rp_threshold": 0.25, "sparse_search_volume": 30, "maj_vote_spk_count": false, "chunk_cluster_count": 50, "embeddings_per_chunk": 10000}}, "msdd_model": {"model_path": "diar_msdd_telephonic", "parameters": {"use_speaker_model_from_ckpt": true, "infer_batch_size": 25, "sigmoid_threshold": [0.7], "seq_eval_mode": false, "split_infer": true, "diar_window_length": 50, "overlap_infer_spk_limit": 5}}, "asr": {"model_path": "stt_en_conformer_ctc_large", "parameters": {"asr_based_vad": false, "asr_based_vad_threshold": 1.0, "asr_batch_size": null, "decoder_delay_in_sec": null, "word_ts_anchor_offset": null, "word_ts_anchor_pos": "start", "fix_word_ts_with_VAD": false, "colored_text": false, "print_time": true, "break_lines": false}, "ctc_decoder_parameters": {"pretrained_language_model": null, "beam_width": 32, "alpha": 0.5, "beta": 2.5}, "realigning_lm_parameters": {"arpa_language_model": null, "min_number_of_words": 3, "max_number_of_words": 10, "logprob_diff_threshold": 1.2}}}}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# and limitations under the License.
"""Longform diarization Service for audio files."""

import json
import tempfile
from pathlib import Path
from typing import List, Optional, Tuple, Union
Expand All @@ -27,6 +28,7 @@
import torch
import torchaudio
from nemo.collections.asr.models import NeuralDiarizer
from omegaconf import OmegaConf
from tensorshare import Backend, TensorShare

from wordcab_transcribe.models import DiarizationOutput
Expand All @@ -43,6 +45,7 @@ class LongFormDiarizeService:
def __init__(
self,
device: str,
domain: str = "telephonic",
) -> None:
"""
Initialize the DiarizeService.
Expand All @@ -53,9 +56,15 @@ def __init__(
Returns:
None
"""
config_path = Path(__file__).parent / "configs" / f"{domain}.json"
with open(config_path, "r") as f:
general_conf = json.load(f)

config = OmegaConf.create(general_conf)
self.diarization_model = NeuralDiarizer.from_pretrained(
model_name="diar_msdd_telephonic"
).to(device)
self.diarization_model._cfg = config
# TODO: Ability to modify config

def __call__(
Expand Down
38 changes: 38 additions & 0 deletions src/wordcab_transcribe/services/post_processing_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from typing import List, Tuple, Union

from wordcab_transcribe.config import settings
from wordcab_transcribe.models import (
DiarizationOutput,
DiarizationSegment,
Expand All @@ -40,6 +41,13 @@ def __init__(self) -> None:
"""Initialize the PostProcessingService."""
self.sample_rate = 16000

if settings.enable_punctuation_based_alignment:
from deepmultilingualpunctuation import PunctuationModel

self.punct_model = PunctuationModel(model="kredor/punctuate-all")
else:
self.punct_model = None

def single_channel_speaker_mapping(
self,
transcript_segments: List[Utterance],
Expand Down Expand Up @@ -379,6 +387,36 @@ def reconstruct_multi_channel_utterances(

return [Utterance(**sentence) for sentence in sentences]

def punctuation_based_alignment(
    self,
    utterances: List[Utterance],
    speaker_timestamps: DiarizationOutput,
) -> List[Utterance]:
    """
    Placeholder for the punctuation-based speaker alignment.

    The alignment itself is not implemented yet. The utterances are returned
    unchanged so that a caller which rebinds its utterances to this method's
    result keeps a usable list instead of receiving None.

    Args:
        utterances (List[Utterance]): Utterances to align.
        speaker_timestamps (DiarizationOutput): Diarization result. Currently unused.

    Returns:
        List[Utterance]: The input utterances, unmodified.
    """
    # TODO: implement the alignment. Draft of the intended approach, kept for
    # reference. NOTE(review): the draft relies on `re` and on a `word_dict`
    # that are not defined in this method; both must be provided before it
    # can be enabled.
    #
    # word_list = []
    # for utterance in utterances:
    #     for word in utterance.words:
    #         word_list.append(word.word)
    #
    # labeled_words = self.punct_model.predict(word_list)
    #
    # ending_puncts = ".?!"
    # model_puncts = ".,;:!?"
    #
    # def is_acronym(w):
    #     return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", w)
    #
    # for ix, (word, labeled_tuple) in enumerate(zip(word_list, labeled_words)):
    #     if (
    #         word
    #         and labeled_tuple[1] in ending_puncts
    #         and (word[-1] not in model_puncts or is_acronym(word))
    #     ):
    #         word += labeled_tuple[1]
    #         if word.endswith(".."):
    #             word = word.rstrip(".")
    #         word_dict["word"] = word
    return utterances

def final_processing_before_returning(
self,
utterances: List[Utterance],
Expand Down

0 comments on commit 4422e79

Please sign in to comment.