From 58ded010cd93c9c26cf7d997283faf795628157d Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Wed, 28 Aug 2024 13:14:37 -0700 Subject: [PATCH 01/16] Checkpoint for whisper and speechbrain transcription --- docs/source/changelog/changelog_3.0.rst | 8 + environment.yml | 3 + montreal_forced_aligner/abc.py | 2 +- .../pronunciation_probabilities.py | 33 +- .../acoustic_modeling/trainer.py | 4 +- .../acoustic_modeling/triphone.py | 23 - montreal_forced_aligner/alignment/base.py | 8 +- .../alignment/multiprocessing.py | 82 +- .../command_line/anchor.py | 3 +- montreal_forced_aligner/command_line/g2p.py | 29 +- montreal_forced_aligner/command_line/mfa.py | 8 +- .../command_line/transcribe.py | 203 +++- .../command_line/validate.py | 8 +- montreal_forced_aligner/config.py | 2 + montreal_forced_aligner/corpus/base.py | 14 +- montreal_forced_aligner/corpus/features.py | 6 +- .../corpus/multiprocessing.py | 38 +- montreal_forced_aligner/data.py | 116 ++- montreal_forced_aligner/db.py | 17 +- .../diarization/multiprocessing.py | 23 +- .../diarization/speaker_diarizer.py | 30 +- .../dictionary/multispeaker.py | 313 +++--- montreal_forced_aligner/exceptions.py | 10 +- montreal_forced_aligner/g2p/generator.py | 202 ++-- montreal_forced_aligner/g2p/mixins.py | 31 +- .../g2p/phonetisaurus_trainer.py | 10 +- .../language_modeling/multiprocessing.py | 15 +- montreal_forced_aligner/models.py | 13 +- .../online/transcription.py | 52 + .../tokenization/english.py | 58 +- .../tokenization/simple.py | 13 +- .../transcription/multiprocessing.py | 471 ++++++++- .../transcription/transcriber.py | 904 ++++++++++++------ montreal_forced_aligner/utils.py | 18 +- .../vad/multiprocessing.py | 111 ++- montreal_forced_aligner/vad/segmenter.py | 15 + .../validation/corpus_validator.py | 148 ++- tests/test_commandline_train.py | 1 + tests/test_commandline_transcribe.py | 77 ++ tests/test_g2p.py | 4 +- 40 files changed, 2236 insertions(+), 890 deletions(-) diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst index 45fc1462..5a789434 100644 --- a/docs/source/changelog/changelog_3.0.rst +++ b/docs/source/changelog/changelog_3.0.rst @@ -5,6 +5,14 @@ 3.0 Changelog ************* +3.1.4 +----- + +- Optimized :code:`mfa g2p` to better use multiple processes +- Added :code:`--export_scores` to :code:`mfa g2p` for adding a column representing the final weights of the generated pronunciations +- Added :code:`--output_directory` to :code:`mfa validate` to save generated validation files rather than the temporary directory +- Fixed a bug in cutoff modeling that was preventing them from being properly parsed + 3.1.3 ----- diff --git a/environment.yml b/environment.yml index 4b2313e2..e8256b47 100644 --- a/environment.yml +++ b/environment.yml @@ -55,6 +55,9 @@ dependencies: - build - twine - speechbrain + - kenlm + - pygtrie + - faster-whisper - python-mecab-ko - jamo - pythainlp diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 8f2c4f77..d7f345fa 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -818,7 +818,7 @@ def setup_logger(self) -> None: f"You are currently running an older version of MFA ({current_version}) than the latest available ({latest_version}). " f"To update, please run mfa_update." 
) - except KeyError: + except Exception: pass if re.search(r"\d+\.\d+\.\d+a", current_version) is not None: logger.debug( diff --git a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py index 0903b3cd..290c6408 100644 --- a/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py +++ b/montreal_forced_aligner/acoustic_modeling/pronunciation_probabilities.py @@ -352,9 +352,6 @@ def train_pronunciation_probabilities(self) -> None: return silence_prob_sum = 0 - initial_silence_prob_sum = 0 - final_silence_correction_sum = 0 - final_non_silence_correction_sum = 0 with self.worker.session() as session: dictionaries = session.query(Dictionary).all() @@ -396,25 +393,25 @@ def train_pronunciation_probabilities(self) -> None: data[k] = 0.5 if self.silence_probabilities: d.silence_probability = data["silence_probability"] - d.initial_silence_probability = data["initial_silence_probability"] - d.final_silence_correction = data["final_silence_correction"] - d.final_non_silence_correction = data["final_non_silence_correction"] + # d.initial_silence_probability = data["initial_silence_probability"] + # d.final_silence_correction = data["final_silence_correction"] + # d.final_non_silence_correction = data["final_non_silence_correction"] silence_prob_sum += d.silence_probability - initial_silence_prob_sum += d.initial_silence_probability - final_silence_correction_sum += d.final_silence_correction - final_non_silence_correction_sum += d.final_non_silence_correction + # initial_silence_prob_sum += d.initial_silence_probability + # final_silence_correction_sum += d.final_silence_correction + # final_non_silence_correction_sum += d.final_non_silence_correction if self.silence_probabilities: self.worker.silence_probability = silence_prob_sum / len(dictionaries) - self.worker.initial_silence_probability = initial_silence_prob_sum / len( - dictionaries - ) - self.worker.final_silence_correction = final_silence_correction_sum / len( - dictionaries - ) - self.worker.final_non_silence_correction = ( - final_non_silence_correction_sum / len(dictionaries) - ) + # self.worker.initial_silence_probability = initial_silence_prob_sum / len( + # dictionaries + # ) + # self.worker.final_silence_correction = final_silence_correction_sum / len( + # dictionaries + # ) + # self.worker.final_non_silence_correction = ( + # final_non_silence_correction_sum / len(dictionaries) + # ) session.commit() self.worker.write_lexicon_information() return diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index 3847981e..745bea45 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -183,6 +183,7 @@ def __init__( self.add_config(k, v) self.final_alignment = True self.model_version = model_version + self.boost_silence = 1.5 @classmethod def default_training_configurations(cls) -> List[Tuple[str, Dict[str, Any]]]: @@ -585,7 +586,6 @@ def train(self) -> None: not self.current_workflow.done or not self.current_workflow.working_directory.exists() ): - logger.debug(f"Skipping {self.current_aligner.identifier} alignments") self.align() with self.session() as session: session.query(WordInterval).delete() @@ -595,6 +595,8 @@ def train(self) -> None: self.analyze_alignments() if self.current_subset != 0: self.quality_check_subset() + else: + logger.debug(f"Skipping {self.current_aligner.identifier} 
alignments") self.set_current_workflow(trainer.identifier) if trainer.identifier.startswith("pronunciation_probabilities"): diff --git a/montreal_forced_aligner/acoustic_modeling/triphone.py b/montreal_forced_aligner/acoustic_modeling/triphone.py index 33c75fcf..2cda9759 100644 --- a/montreal_forced_aligner/acoustic_modeling/triphone.py +++ b/montreal_forced_aligner/acoustic_modeling/triphone.py @@ -405,29 +405,6 @@ def _setup_tree(self, init_from_previous=False, initial_mix_up=True) -> None: for q_set in questions: train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set])) - # Remove questions containing silence and other phones - train_logger.debug("Filtering the following sets for containing silence phone:") - silence_phone_id = self.phone_mapping[self.optional_silence_phone] - silence_sets = [ - x for x in questions if silence_phone_id in x and x != [silence_phone_id] - ] - filtered = [] - existing_sets = {tuple(x) for x in questions} - for q_set in silence_sets: - train_logger.debug(", ".join([self.reversed_phone_mapping[x] for x in q_set])) - - for q_set in questions: - if silence_phone_id not in q_set or q_set == [silence_phone_id]: - filtered.append(q_set) - continue - q_set = [x for x in q_set if x != silence_phone_id] - if not q_set: - continue - if tuple(q_set) in existing_sets: - continue - filtered.append(q_set) - questions = filtered - extra_questions = self.worker.extra_questions_mapping if extra_questions: train_logger.debug(f"Adding {len(extra_questions)} questions") diff --git a/montreal_forced_aligner/alignment/base.py b/montreal_forced_aligner/alignment/base.py index 35fe5634..9596d24f 100644 --- a/montreal_forced_aligner/alignment/base.py +++ b/montreal_forced_aligner/alignment/base.py @@ -623,9 +623,9 @@ def compute_pronunciation_probabilities(self): { "id": d_id, "silence_probability": silence_probability, - "initial_silence_probability": initial_silence_probability, - "final_silence_correction": final_silence_correction, - "final_non_silence_correction": final_non_silence_correction, + # "initial_silence_probability": initial_silence_probability, + # "final_silence_correction": final_silence_correction, + # "final_non_silence_correction": final_non_silence_correction, } ) @@ -1328,8 +1328,6 @@ def export_files( Format to save alignments, one of 'long_textgrids' (the default), 'short_textgrids', or 'json', passed to praatio include_original_text: bool Flag for including the original text of the corpus files as a tier - workflow: :class:`~montreal_forced_aligner.data.WorkflowType` - Workflow to use when exporting files """ if isinstance(output_directory, str): output_directory = Path(output_directory) diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index b4fab296..06775ad7 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -7,6 +7,7 @@ import collections import json import logging +import math import multiprocessing as mp import os import shutil @@ -115,16 +116,14 @@ class GeneratePronunciationsArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run - text_int_paths: dict[int, Path] - Per dictionary text SCP paths - ali_paths: dict[int, 
Path] - Per dictionary alignment paths - model_path: :class:`~pathlib.Path` - Acoustic model path + aligner: :class:`kalpy.gmm.align.GmmAligner` + GmmAligner to use + lexicon_compilers: dict[int, :class:`kalpy.fstext.lexicon.LexiconCompiler`] + Lexicon compilers for each pronunciation dictionary for_g2p: bool Flag for training a G2P model with acoustic information """ @@ -143,12 +142,16 @@ class AlignmentExtractionArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run - model_path: :class:`~pathlib.Path` - Acoustic model path + working_directory: :class:`~pathlib.Path` + Working directory + lexicon_compilers: dict[int, :class:`kalpy.fstext.lexicon.LexiconCompiler`] + Lexicon compilers for each pronunciation dictionary + aligner: :class:`kalpy.gmm.align.GmmAligner` + GmmAligner to use frame_shift: float Frame shift in seconds ali_paths: dict[int, Path] @@ -180,8 +183,8 @@ class ExportTextGridArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run export_frame_shift: float @@ -215,10 +218,12 @@ class CompileTrainGraphsArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory tree_path: :class:`~pathlib.Path` Path to tree file model_path: :class:`~pathlib.Path` @@ -243,10 +248,12 @@ class AlignArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory model_path: :class:`~pathlib.Path` Path to model file align_options: dict[str, Any] @@ -271,8 +278,8 @@ class AnalyzeAlignmentsArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run model_path: :class:`~pathlib.Path` @@ -294,8 +301,8 @@ class FineTuneArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run tree_path: :class:`~pathlib.Path` @@ -335,10 +342,12 @@ class PhoneConfidenceArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: 
:class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory model_path: :class:`~pathlib.Path` Path to model file phone_pdf_counts_path: :class:`~pathlib.Path` @@ -359,8 +368,8 @@ class AccStatsArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run working_directory: :class:`~pathlib.Path` @@ -424,9 +433,12 @@ def _run(self): session.query(Word).filter(Word.word_type == WordType.interjection).all() ) if interjection_words: - max_count = max(x.count for x in interjection_words) + max_count = max(math.log(x.count) for x in interjection_words) for w in interjection_words: - cost = max_count / w.count + count = math.log(w.count) + if count == 0: + count = 0.01 + cost = max_count / count interjection_costs[w.word] = cost if self.use_g2p: text_column = Utterance.normalized_character_text @@ -1366,15 +1378,15 @@ def _run(self) -> None: intervals = transcription.generate_ctm( self.transition_model, lexicon_compiler.phone_table, self.frame_shift ) - utterance = int(transcription.utterance_id.split("-")[-1]) + utterance_id = int(transcription.utterance_id.split("-")[-1]) try: ctm = lexicon_compiler.phones_to_pronunciations( transcription.words, intervals, transcription=True, - text=utterance_texts.get(utterance, None), + text=utterance_texts.get(utterance_id, None), ) - ctm.update_utterance_boundaries(*utterance_times[utterance]) + ctm.update_utterance_boundaries(*utterance_times[utterance_id]) except Exception: exc_type, exc_value, exc_traceback = sys.exc_info() utterance, sound_file_path, text_file_path = ( @@ -1384,10 +1396,12 @@ def _run(self) -> None: .join(Utterance.file) .join(File.sound_file) .join(File.text_file) - .filter(Utterance.id == utterance) + .filter(Utterance.id == utterance_id) .first() ) - extraction_logger.debug(f"Error processing {utterance}:") + extraction_logger.debug( + f"Error processing {utterance} ({utterance_id}):" + ) extraction_logger.debug( f"Utterance information: {sound_file_path}, {text_file_path}, {utterance.begin} - {utterance.end}" ) @@ -1403,7 +1417,7 @@ def _run(self) -> None: traceback_lines, self.log_path, ) - self.callback((utterance, d.id, ctm)) + self.callback((utterance_id, d.id, ctm)) else: ali_path = job.construct_path(workflow.working_directory, "ali", "ark", d.id) if not ali_path.exists(): diff --git a/montreal_forced_aligner/command_line/anchor.py b/montreal_forced_aligner/command_line/anchor.py index 157ccdec..5e99b506 100644 --- a/montreal_forced_aligner/command_line/anchor.py +++ b/montreal_forced_aligner/command_line/anchor.py @@ -22,7 +22,8 @@ def anchor_cli(*args, **kwargs) -> None: # pragma: no cover """ try: from anchor.command_line import main - except ImportError: + except ImportError as e: + logger.error(f"Exception: {e}") logger.error( "Anchor annotator utility is not installed, please install it via `conda install -c conda-forge anchor-annotator`." 
) diff --git a/montreal_forced_aligner/command_line/g2p.py b/montreal_forced_aligner/command_line/g2p.py index 5be6517a..24a000ca 100644 --- a/montreal_forced_aligner/command_line/g2p.py +++ b/montreal_forced_aligner/command_line/g2p.py @@ -65,6 +65,18 @@ help="Included words enclosed by brackets, job_name.e. [...], (...), <...>.", default=False, ) +@click.option( + "--export_scores", + is_flag=True, + help="Add a column to export for the score of the generated pronunciation.", + default=False, +) +@click.option( + "--sorted", + is_flag=True, + help="Ensure output file is sorted alphabetically (slower).", + default=False, +) @common_options @click.help_option("-h", "--help") @click.pass_context @@ -83,6 +95,7 @@ def g2p_cli(context, **kwargs) -> None: dictionary_path = kwargs.get("dictionary_path", None) use_stdin = input_path == pathlib.Path("-") use_stdout = output_path == pathlib.Path("-") + export_scores = kwargs.get("export_scores", False) if input_path.is_dir(): per_utterance = False @@ -134,14 +147,22 @@ def g2p_cli(context, **kwargs) -> None: continue pronunciations = g2p.rewriter(word) if not pronunciations: - output.write(f"{word}\t\n") - for p in pronunciations: - output.write(f"{word}\t{p}\n") + if export_scores: + output.write(f"{word}\t\t\n") + else: + output.write(f"{word}\t\n") + for p, score in pronunciations: + if export_scores: + output.write(f"{word}\t{p}\t{score}\n") + else: + output.write(f"{word}\t{p}\n") output.flush() finally: output.close() else: - g2p.export_pronunciations(output_path) + g2p.export_pronunciations( + output_path, export_scores=export_scores, ensure_sorted=kwargs.get("sorted", False) + ) except Exception: g2p.dirty = True raise diff --git a/montreal_forced_aligner/command_line/mfa.py b/montreal_forced_aligner/command_line/mfa.py index 8008b90a..739e931f 100644 --- a/montreal_forced_aligner/command_line/mfa.py +++ b/montreal_forced_aligner/command_line/mfa.py @@ -32,7 +32,11 @@ from montreal_forced_aligner.command_line.train_ivector_extractor import train_ivector_cli from montreal_forced_aligner.command_line.train_lm import train_lm_cli from montreal_forced_aligner.command_line.train_tokenizer import train_tokenizer_cli -from montreal_forced_aligner.command_line.transcribe import transcribe_corpus_cli +from montreal_forced_aligner.command_line.transcribe import ( + transcribe_corpus_cli, + transcribe_speechbrain_cli, + transcribe_whisper_cli, +) from montreal_forced_aligner.command_line.validate import ( validate_corpus_cli, validate_dictionary_cli, @@ -194,6 +198,8 @@ def version_cli(): mfa_cli.add_command(train_lm_cli) mfa_cli.add_command(train_tokenizer_cli) mfa_cli.add_command(transcribe_corpus_cli) +mfa_cli.add_command(transcribe_speechbrain_cli) +mfa_cli.add_command(transcribe_whisper_cli) mfa_cli.add_command(validate_corpus_cli) mfa_cli.add_command(validate_dictionary_cli) mfa_cli.add_command(version_cli) diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index 70ee6015..b379221f 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -12,9 +12,14 @@ validate_dictionary, validate_language_model, ) -from montreal_forced_aligner.transcription import Transcriber +from montreal_forced_aligner.data import Language +from montreal_forced_aligner.transcription.transcriber import ( + SpeechbrainTranscriber, + Transcriber, + WhisperTranscriber, +) -__all__ = ["transcribe_corpus_cli"] +__all__ = ["transcribe_corpus_cli", 
"transcribe_speechbrain_cli", "transcribe_whisper_cli"] @click.command( @@ -39,7 +44,7 @@ @click.option( "--config_path", "-c", - help="Path to config file to use for training.", + help="Path to config file to use for transcription.", type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), ) @click.option( @@ -132,3 +137,195 @@ def transcribe_corpus_cli(context, **kwargs) -> None: raise finally: transcriber.cleanup() + + +@click.command( + name="transcribe_speechbrain", + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + allow_interspersed_args=True, + ), + short_help="Transcribe utterances using an ASR model trained by SpeechBrain", +) +@click.argument( + "corpus_directory", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), +) +@click.argument( + "language", + type=click.Choice( + sorted( + [ + "arabic", + "german", + "english", + "spanish", + "french", + "italian", + "kinyarwanda", + "portuguese", + "mandarin", + ] + ) + ), +) +@click.argument( + "output_directory", type=click.Path(file_okay=False, dir_okay=True, path_type=Path) +) +@click.option( + "--architecture", + help="ASR model architecture", + default=SpeechbrainTranscriber.ARCHITECTURES[0], + type=click.Choice(SpeechbrainTranscriber.ARCHITECTURES), +) +@click.option( + "--config_path", + "-c", + help="Path to config file to use for transcription.", + type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), +) +@click.option( + "--speaker_characters", + "-s", + help="Number of characters of file names to use for determining speaker, " + "default is to use directory names.", + type=str, + default="0", +) +@click.option( + "--audio_directory", + "-a", + help="Audio directory root to use for finding audio files.", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), +) +@click.option( + "--cuda/--no_cuda", + "cuda", + help="Flag for using CUDA for Whisper's model", + default=False, +) +@click.option( + "--evaluate", + "evaluation_mode", + is_flag=True, + help="Evaluate the transcription against golden texts.", + default=False, +) +@common_options +@click.help_option("-h", "--help") +@click.pass_context +def transcribe_speechbrain_cli(context, **kwargs) -> None: + """ + Transcribe utterances using an ASR model trained by SpeechBrain. 
+ """ + if kwargs.get("profile", None) is not None: + config.profile = kwargs.pop("profile") + config.update_configuration(kwargs) + + config_path = kwargs.get("config_path", None) + corpus_directory = kwargs["corpus_directory"].absolute() + output_directory = kwargs["output_directory"] + transcriber = SpeechbrainTranscriber( + corpus_directory=corpus_directory, + **SpeechbrainTranscriber.parse_parameters(config_path, context.params, context.args), + ) + try: + transcriber.setup() + transcriber.transcribe() + transcriber.export_files(output_directory) + except Exception: + transcriber.dirty = True + raise + finally: + transcriber.cleanup() + + +@click.command( + name="transcribe_whisper", + context_settings=dict( + ignore_unknown_options=True, + allow_extra_args=True, + allow_interspersed_args=True, + ), + short_help="Transcribe utterances using a Whisper ASR model via faster-whisper", +) +@click.argument( + "corpus_directory", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), +) +@click.argument( + "output_directory", type=click.Path(file_okay=False, dir_okay=True, path_type=Path) +) +@click.option( + "--architecture", + help="Model size to use", + default=WhisperTranscriber.ARCHITECTURES[0], + type=click.Choice(WhisperTranscriber.ARCHITECTURES), +) +@click.option( + "--language", + help="Language to use for transcription.", + default=Language.unknown.name, + type=click.Choice([x.name for x in Language]), +) +@click.option( + "--config_path", + "-c", + help="Path to config file to use for transcription.", + type=click.Path(exists=True, file_okay=True, dir_okay=False, path_type=Path), +) +@click.option( + "--speaker_characters", + "-s", + help="Number of characters of file names to use for determining speaker, " + "default is to use directory names.", + type=str, + default="0", +) +@click.option( + "--audio_directory", + "-a", + help="Audio directory root to use for finding audio files.", + type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), +) +@click.option( + "--cuda/--no_cuda", + "cuda", + help="Flag for using CUDA for Whisper's model", + default=False, +) +@click.option( + "--evaluate", + "evaluation_mode", + is_flag=True, + help="Evaluate the transcription against golden texts.", + default=False, +) +@common_options +@click.help_option("-h", "--help") +@click.pass_context +def transcribe_whisper_cli(context, **kwargs) -> None: + """ + Transcribe utterances using a Whisper ASR model via faster-whisper. 
+ """ + if kwargs.get("profile", None) is not None: + config.profile = kwargs.pop("profile") + config.update_configuration(kwargs) + + config_path = kwargs.get("config_path", None) + corpus_directory = kwargs["corpus_directory"].absolute() + output_directory = kwargs["output_directory"] + transcriber = WhisperTranscriber( + corpus_directory=corpus_directory, + **WhisperTranscriber.parse_parameters(config_path, context.params, context.args), + ) + try: + transcriber.setup() + transcriber.transcribe() + transcriber.export_files(output_directory) + except Exception: + transcriber.dirty = True + raise + finally: + transcriber.cleanup() diff --git a/montreal_forced_aligner/command_line/validate.py b/montreal_forced_aligner/command_line/validate.py index 690a9444..1a7d3626 100644 --- a/montreal_forced_aligner/command_line/validate.py +++ b/montreal_forced_aligner/command_line/validate.py @@ -34,6 +34,11 @@ type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), ) @click.argument("dictionary_path", type=click.UNPROCESSED, callback=validate_dictionary) +@click.option( + "--output_directory", + help="Directory to save validation output files.", + type=click.Path(exists=False, file_okay=False, dir_okay=True, path_type=Path), +) @click.option( "--acoustic_model_path", help="Acoustic model to use in testing alignments.", @@ -108,6 +113,7 @@ def validate_corpus_cli(context, **kwargs) -> None: config.update_configuration(kwargs) kwargs["USE_THREADING"] = False + output_directory = kwargs.get("output_directory", None) config_path = kwargs.get("config_path", None) corpus_directory = kwargs["corpus_directory"].absolute() dictionary_path = kwargs["dictionary_path"] @@ -135,7 +141,7 @@ def validate_corpus_cli(context, **kwargs) -> None: **TrainingValidator.parse_parameters(config_path, context.params, context.args), ) try: - validator.validate() + validator.validate(output_directory=output_directory) except Exception: validator.dirty = True raise diff --git a/montreal_forced_aligner/config.py b/montreal_forced_aligner/config.py index 4ef7328a..a878c214 100644 --- a/montreal_forced_aligner/config.py +++ b/montreal_forced_aligner/config.py @@ -152,6 +152,7 @@ def update_command_history(command_data: Dict[str, Any]) -> None: AUTO_SERVER = True TEMPORARY_DIRECTORY = get_temporary_directory() GITHUB_TOKEN = None +HF_TOKEN = None BLAS_NUM_THREADS = 1 BYTES_LIMIT = 100e6 CURRENT_PROFILE_NAME = os.getenv(MFA_PROFILE_VARIABLE, "global") @@ -187,6 +188,7 @@ class MfaProfile: auto_server: bool = True temporary_directory: pathlib.Path = get_temporary_directory() github_token: typing.Optional[str] = None + hf_token: typing.Optional[str] = None def __getitem__(self, item): """Get key from profile""" diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index 2ee44541..1bd996ab 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -654,6 +654,7 @@ def normalize_text_arguments(self): tokenizers, getattr(self, "g2p_model", None), getattr(self, "ignore_case", True), + getattr(self, "use_cutoff_model", False), ) for j in jobs ] @@ -675,7 +676,7 @@ def normalize_text(self) -> None: pronunciation_insert_mappings = [] word_indexes = {} word_mapping_ids = {} - max_mapping_ids = {} + max_mapping_id = 0 log_directory.mkdir(parents=True, exist_ok=True) update_mapping = [] word_key = self.get_next_primary_key(Word) @@ -709,14 +710,15 @@ def normalize_text(self) -> None: "dictionary_id": 1, } word_key += 1 - max_mapping_ids[1] = 
word_key - 1 + max_mapping_id = word_key - 1 for w_id, m_id, d_id, w, wt in words: if wt is WordType.oov and w not in self.specials_set: existing_oovs[(d_id, w)] = {"id": w_id, "count": 0, "included": False} continue word_indexes[(d_id, w)] = w_id - word_mapping_ids[(d_id, w)] = m_id - max_mapping_ids[d_id] = m_id + word_mapping_ids[w] = m_id + if m_id > max_mapping_id: + max_mapping_id = m_id to_g2p = set() word_to_g2p_mapping = {x: collections.defaultdict(set) for x in dictionaries.keys()} word_counts = collections.defaultdict(int) @@ -832,7 +834,7 @@ def normalize_text(self) -> None: log_file.write(f"For dictionary {dict_id}:\n") for w, ps in mapping.items(): log_file.write(f" - {w} ({', '.join(sorted(ps))})\n") - max_mapping_ids[dict_id] += 1 + max_mapping_id += 1 included = False if hasattr(self, "brackets") and any( w.startswith(b) for b, _ in self.brackets @@ -859,7 +861,7 @@ def normalize_text(self) -> None: word_insert_mappings[(dict_id, w)] = { "id": word_key, - "mapping_id": max_mapping_ids[d_id], + "mapping_id": max_mapping_id, "word": w, "count": word_counts[(dict_id, w)], "dictionary_id": dict_id, diff --git a/montreal_forced_aligner/corpus/features.py b/montreal_forced_aligner/corpus/features.py index 335d809a..9db52787 100644 --- a/montreal_forced_aligner/corpus/features.py +++ b/montreal_forced_aligner/corpus/features.py @@ -595,7 +595,7 @@ class FeatureConfigMixin: def __init__( self, feature_type: str = "mfcc", - use_energy: bool = True, + use_energy: bool = False, raw_energy: bool = False, frame_shift: int = 10, frame_length: int = 25, @@ -605,8 +605,8 @@ def __init__( sample_frequency: int = 16000, allow_downsample: bool = True, allow_upsample: bool = True, - dither: float = 0.0001, - energy_floor: float = 1.0, + dither: float = 0.0, + energy_floor: float = 0.0, num_coefficients: int = 13, num_mel_bins: int = 23, cepstral_lifter: float = 22, diff --git a/montreal_forced_aligner/corpus/multiprocessing.py b/montreal_forced_aligner/corpus/multiprocessing.py index 71049a63..dad3e5a7 100644 --- a/montreal_forced_aligner/corpus/multiprocessing.py +++ b/montreal_forced_aligner/corpus/multiprocessing.py @@ -236,6 +236,7 @@ class NormalizeTextArguments(MfaArguments): tokenizers: typing.Union[typing.Dict[int, SimpleTokenizer], Language] g2p_model: typing.Optional[G2PModel] ignore_case: bool + use_cutoff_model: bool @dataclass @@ -265,6 +266,7 @@ def __init__(self, args: NormalizeTextArguments): self.tokenizers = args.tokenizers self.g2p_model = args.g2p_model self.ignore_case = args.ignore_case + self.use_cutoff_model = args.use_cutoff_model def _run(self): """Run the function""" @@ -300,6 +302,18 @@ def _run(self): for u_id, u_text in utterances: if simple_tokenization: normalized_text, normalized_character_text, oovs = tokenizer(u_text) + if self.use_cutoff_model: + new_text = [] + text = normalized_text.split() + for i, w in enumerate(text): + if w == d.cutoff_word and i != len(text) - 1: + next_w = text[i + 1] + if tokenizer.word_table.member( + next_w + ) and not tokenizer.bracket_regex.match(next_w): + w = f"{d.cutoff_word[:-1]}-{next_w}{d.cutoff_word[-1]}" + new_text.append(w) + normalized_text = " ".join(new_text) self.callback( ( { @@ -335,6 +349,14 @@ def _run(self): ) else: tokenizer = self.tokenizers + if isinstance(tokenizer, Language): + from montreal_forced_aligner.tokenization.spacy import ( + generate_language_tokenizer, + ) + + tokenizer = generate_language_tokenizer( + tokenizer, ignore_case=self.ignore_case + ) utterances = ( session.query(Utterance.id, 
Utterance.text) .filter(Utterance.text != "") @@ -342,17 +364,23 @@ def _run(self): ) for u_id, u_text in utterances: if tokenizer is None: - normalized_text, normalized_character_text = u_text, u_text - oovs = [] + normalized_text, pronunciation_form = u_text, u_text else: - normalized_text, normalized_character_text, oovs = tokenizer(u_text) + tokenized = tokenizer(u_text) + if isinstance(tokenized, tuple): + normalized_text, pronunciation_form = tokenized[:2] + else: + if not isinstance(tokenized, str): + tokenized = " ".join([x.text for x in tokenized]) + if self.ignore_case: + tokenized = tokenized.lower() + normalized_text, pronunciation_form = tokenized, tokenized.lower() self.callback( ( { "id": u_id, - "oovs": " ".join(sorted(oovs)), "normalized_text": normalized_text, - "normalized_character_text": normalized_character_text, + "normalized_character_text": pronunciation_form, }, None, ) diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index b3b63657..2b9f2eca 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -273,8 +273,8 @@ class MfaArguments: ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run """ @@ -475,6 +475,112 @@ class ClusterType(enum.Enum): meanshift = "meanshift" +ISO_LANGUAGE_MAPPING = { + "afrikaans": "af", + "amharic": "am", + "arabic": "ar", + "assamese": "as", + "azerbaijani": "az", + "bashkir": "ba", + "belarusian": "be", + "bulgarian": "bg", + "bengali": "bn", + "tibetan": "bo", + "breton": "br", + "bosnian": "bs", + "catalan": "ca", + "czech": "cs", + "welsh": "cy", + "danish": "da", + "german": "de", + "greek": "el", + "english": "en", + "spanish": "es", + "estonian": "et", + "basque": "eu", + "farsi": "fa", + "finnish": "fi", + "faroese": "fo", + "french": "fr", + "galician": "gl", + "gujarati": "gu", + "hausa": "ha", + "hebrew": "he", + "hindi": "hi", + "croatian": "hr", + "haitian": "ht", + "hungarian": "hu", + "armenian": "hy", + "indonesian": "id", + "icelandic": "is", + "italian": "it", + "japanese": "ja", + "georgian": "ka", + "kazakh": "kk", + "central khmer": "km", + "kannada": "kn", + "korean": "ko", + "latin": "la", + "luxembourgish": "lb", + "lingala": "ln", + "lao": "lo", + "lithuanian": "lt", + "latvian": "lv", + "malagasy": "mg", + "maori": "mi", + "macedonian": "mk", + "malayalam": "ml", + "mongolian": "mn", + "marathi": "mr", + "malay": "ms", + "maltese": "mt", + "burmese": "my", + "nepali": "ne", + "dutch": "nl", + "flemish": "nl", + "norwegian nynorsk": "nn", + "norwegian": "no", + "occitan": "oc", + "punjabi": "pa", + "polish": "pl", + "pashto": "ps", + "portuguese": "pt", + "romanian": "ro", + "moldavian": "ro", + "russian": "ru", + "sanskrit": "sa", + "sindhi": "sd", + "sinhala": "si", + "slovak": "sk", + "slovenian": "sl", + "shona": "sn", + "somali": "so", + "albanian": "sq", + "serbian": "sr", + "sundanese": "su", + "swedish": "sv", + "swahili": "sw", + "tamil": "ta", + "telegu": "te", + "tajik": "tg", + "thai": "th", + "turkmen": "tk", + "tagalog": "tl", + "turkish": "tr", + "tatar": "tt", + "ukrainian": "uk", + "urdu": "ur", + "uzbek": "uz", + "vietnamese": "vi", + "yiddish": "yi", + "yoruba": "yo", + "yue": "yue", + "chinese": "zh", + "kinyarwanda": "rw", + "mandarin": "zh-CN", +} + + class Language(enum.Enum): """Enum for 
supported languages""" @@ -510,6 +616,12 @@ def __str__(self) -> str: """Name of phone set""" return self.name + @property + def iso_code(self) -> typing.Optional[str]: + if self.value in ISO_LANGUAGE_MAPPING: + return ISO_LANGUAGE_MAPPING[self.value] + return None + class ManifoldAlgorithm(enum.Enum): """Enum for supported manifold visualization algorithms""" diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py index 92f7c432..801d0ad9 100644 --- a/montreal_forced_aligner/db.py +++ b/montreal_forced_aligner/db.py @@ -406,10 +406,13 @@ def word_mapping(self): session = sqlalchemy.orm.Session.object_session(self) query = ( session.query(Word.word, Word.mapping_id) - .filter(Word.dictionary_id == self.id) .filter(Word.included == True) # noqa .order_by(Word.mapping_id) ) + if self.name != "default": + query = query.filter(Word.dictionary_id == self.id) + else: + query = query.group_by(Word.word, Word.mapping_id) self._word_mapping = {} for w, mapping_id in query: self._word_mapping[w] = mapping_id @@ -425,10 +428,15 @@ def word_table(self): session = sqlalchemy.orm.Session.object_session(self) query = ( session.query(Word.word, Word.mapping_id) - .filter(Word.dictionary_id == self.id) .filter(Word.included == True) # noqa .order_by(Word.mapping_id) ) + if self.name != "default": + query = query.filter( + sqlalchemy.or_(Word.dictionary_id == self.id, Word.word.in_(self.special_set)) + ) + else: + query = query.group_by(Word.word, Word.mapping_id) self._word_table = pywrapfst.SymbolTable() for w, mapping_id in query: self._word_table.add_symbol(w, mapping_id) @@ -462,11 +470,14 @@ def word_pronunciations(self): query = ( session.query(Word.word, Pronunciation.pronunciation) .join(Pronunciation.word) - .filter(Word.dictionary_id == self.id) .filter(Word.included == True) # noqa .filter(Pronunciation.pronunciation != self.oov_phone) .order_by(Word.mapping_id) ) + if self.name != "default": + query = query.filter(Word.dictionary_id == self.id) + else: + query = query.group_by(Word.word, Pronunciation.pronunciation) self._word_pronunciations = {} for w, pronunciation in query: if w not in self._word_pronunciations: diff --git a/montreal_forced_aligner/diarization/multiprocessing.py b/montreal_forced_aligner/diarization/multiprocessing.py index b5150360..49628921 100644 --- a/montreal_forced_aligner/diarization/multiprocessing.py +++ b/montreal_forced_aligner/diarization/multiprocessing.py @@ -9,7 +9,6 @@ import time import typing from pathlib import Path -from queue import Queue import dataclassy @@ -56,11 +55,13 @@ except ImportError: # speechbrain 1.0 from speechbrain.inference.classifiers import EncoderClassifier from speechbrain.inference.speaker import SpeakerRecognition + from speechbrain.utils.metric_stats import EER FOUND_SPEECHBRAIN = True except (ImportError, OSError): FOUND_SPEECHBRAIN = False EncoderClassifier = None SpeakerRecognition = None + EER = None __all__ = [ "PldaClassificationArguments", @@ -700,7 +701,7 @@ def _run(self) -> None: run_opts=run_opts, ) - return_q = Queue(2) + return_q = queue.Queue(2) finished_adding = threading.Event() stopped = threading.Event() loader = UtteranceFileLoader( @@ -756,7 +757,7 @@ class UtteranceFileLoader(threading.Thread): Job identifier session: sqlalchemy.orm.scoped_session Session - return_q: multiprocessing.Queue + return_q: :class:`~queue.Queue` Queue to put waveforms stopped: :class:`~threading.Event` Check for whether the process to exit gracefully @@ -768,9 +769,11 @@ def __init__( self, job_name: int, session: 
sqlalchemy.orm.scoped_session, - return_q: Queue, + return_q: queue.Queue, stopped: threading.Event, finished_adding: threading.Event, + model=None, + for_xvector=True, ): super().__init__() self.job_name = job_name @@ -778,6 +781,8 @@ def __init__( self.return_q = return_q self.stopped = stopped self.finished_adding = finished_adding + self.model = model + self.for_xvector = for_xvector def run(self) -> None: """ @@ -790,7 +795,10 @@ def run(self) -> None: @speechbrain.utils.data_pipeline.takes("segment") @speechbrain.utils.data_pipeline.provides("signal") def audio_pipeline(segment): - return segment.load_audio() + signal = torch.tensor(segment.load_audio()) + if self.model is not None: + signal = self.model.audio_normalizer(signal, 16000) + return signal with self.session() as session: try: @@ -804,9 +812,12 @@ def audio_pipeline(segment): ) .join(Utterance.file) .join(File.sound_file) - .filter(Utterance.xvector == None) # noqa .order_by(Utterance.duration.desc()) ) + if self.for_xvector: + utterances = utterances.filter(Utterance.xvector == None) # noqa + else: + utterances = utterances.filter(Utterance.job_id == self.job_name) if not utterances.count(): self.finished_adding.set() return diff --git a/montreal_forced_aligner/diarization/speaker_diarizer.py b/montreal_forced_aligner/diarization/speaker_diarizer.py index 5509e8f0..adc520a4 100644 --- a/montreal_forced_aligner/diarization/speaker_diarizer.py +++ b/montreal_forced_aligner/diarization/speaker_diarizer.py @@ -55,10 +55,14 @@ bulk_update, ) from montreal_forced_aligner.diarization.multiprocessing import ( + EER, + FOUND_SPEECHBRAIN, ComputeEerArguments, ComputeEerFunction, + EncoderClassifier, PldaClassificationArguments, PldaClassificationFunction, + SpeakerRecognition, SpeechbrainArguments, SpeechbrainClassificationFunction, SpeechbrainEmbeddingFunction, @@ -71,29 +75,6 @@ from montreal_forced_aligner.textgrid import construct_output_path, export_textgrid from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function, thirdparty_binary -try: - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend") - torch_logger.setLevel(logging.ERROR) - torch_logger = logging.getLogger("speechbrain.utils.train_logger") - torch_logger.setLevel(logging.ERROR) - import torch - - try: - from speechbrain.pretrained import EncoderClassifier, SpeakerRecognition - except ImportError: # speechbrain 1.0 - from speechbrain.inference.classifiers import EncoderClassifier - from speechbrain.inference.speaker import SpeakerRecognition - from speechbrain.utils.metric_stats import EER - - FOUND_SPEECHBRAIN = True -except (ImportError, OSError): - FOUND_SPEECHBRAIN = False - EncoderClassifier = None - if TYPE_CHECKING: from montreal_forced_aligner.abc import MetaDict @@ -1272,6 +1253,8 @@ def calculate_eer(self) -> typing.Tuple[float, float]: if not FOUND_SPEECHBRAIN: logger.info("No speechbrain found, skipping EER calculation.") return 0.0, 0.0 + import torch + logger.info("Calculating EER using ground truth speakers...") limit_per_speaker = 5 limit_within_speaker = 30 @@ -1440,6 +1423,7 @@ def refresh_speaker_vectors(self) -> None: s_ivectors.append(u_ivector) if not s_ivectors: continue + print(s_ivectors) mean_ivector = np.mean(np.array(s_ivectors), axis=0) speaker_mean = DoubleVector() speaker_mean.from_numpy(mean_ivector) diff --git a/montreal_forced_aligner/dictionary/multispeaker.py 
b/montreal_forced_aligner/dictionary/multispeaker.py index 80c55bb3..dd9fff2f 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -413,6 +413,8 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: graphemes = set(self.clitic_markers + self.compound_markers) clitic_cleanup_regex = None has_nonnative_speakers = False + self._words_mappings = {} + current_mapping_id = 0 if len(self.clitic_markers) >= 1: other_clitic_markers = self.clitic_markers[1:] if other_clitic_markers: @@ -520,19 +522,19 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: dictionary_id_cache[dictionary_model.path] = dictionary.id if dictionary.default: self._default_dictionary_id = dictionary.id - self._words_mappings[dictionary.id] = {} - current_index = 0 - word_objs.append( - { - "id": word_primary_key, - "mapping_id": current_index, - "word": self.silence_word, - "word_type": WordType.silence, - "dictionary_id": dictionary.id, - } - ) - self._words_mappings[dictionary.id][self.silence_word] = current_index - current_index += 1 + if self.silence_word not in self._words_mappings: + word_objs.append( + { + "id": word_primary_key, + "mapping_id": current_mapping_id, + "word": self.silence_word, + "word_type": WordType.silence, + "dictionary_id": dictionary.id, + } + ) + self._words_mappings[self.silence_word] = current_mapping_id + current_mapping_id += 1 + word_primary_key += 1 pron_objs.append( { @@ -546,7 +548,6 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: "word_id": word_primary_key, } ) - word_primary_key += 1 pronunciation_primary_key += 1 special_words = { @@ -603,17 +604,18 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: wt = special_wt specials_found.add(special_w) break + if word not in self._words_mappings: + self._words_mappings[word] = current_mapping_id + current_mapping_id += 1 word_objs.append( { "id": word_primary_key, - "mapping_id": current_index, + "mapping_id": self._words_mappings[word], "word": word, "word_type": wt, "dictionary_id": dictionary.id, } ) - self._words_mappings[dictionary.id][word] = current_index - current_index += 1 word_cache[word] = word_primary_key word_primary_key += 1 pron_string = " ".join(pron) @@ -644,20 +646,21 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: pronunciation_cache.add((word, pron_string)) phone_counts.update(pron) - for w, wt in special_words.items(): - if w in specials_found: + for word, wt in special_words.items(): + if word in specials_found: continue + if word not in self._words_mappings: + self._words_mappings[word] = current_mapping_id + current_mapping_id += 1 word_objs.append( { "id": word_primary_key, - "mapping_id": current_index, - "word": w, + "mapping_id": self._words_mappings[word], + "word": word, "word_type": wt, "dictionary_id": dictionary.id, } ) - self._words_mappings[dictionary.id][w] = current_index - current_index += 1 pron_objs.append( { "id": pronunciation_primary_key, @@ -672,19 +675,20 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: ) pronunciation_primary_key += 1 word_primary_key += 1 - for s in ["#0", "", ""]: + for word in ["#0", "", ""]: + if word not in self._words_mappings: + self._words_mappings[word] = current_mapping_id + current_mapping_id += 1 word_objs.append( { "id": word_primary_key, - "word": s, + "word": word, "dictionary_id": dictionary.id, - "mapping_id": current_index, + 
"mapping_id": self._words_mappings[word], "word_type": WordType.disambiguation, } ) - self._words_mappings[dictionary.id][s] = current_index word_primary_key += 1 - current_index += 1 if not graphemes: raise DictionaryFileError( @@ -765,8 +769,6 @@ def create_default_dictionary(self): if dictionary is not None: # Already a default dictionary self._default_dictionary_id = dictionary.id return - word_objs = [] - pron_objs = [] dialect = session.query(Dialect).filter_by(name="unknown").first() if dialect is None: dialect = Dialect(name="unknown") @@ -794,111 +796,8 @@ def create_default_dictionary(self): session.add(dictionary) session.commit() - special_words = { - self.silence_word: WordType.silence, - self.oov_word: WordType.oov, - self.bracketed_word: WordType.bracketed, - self.cutoff_word: WordType.cutoff, - self.laughter_word: WordType.laughter, - } self._default_dictionary_id = dictionary.id self.dictionary_lookup[dictionary.name] = dictionary.id - self._words_mappings[dictionary.id] = {} - word_primary_key = self.get_next_primary_key(Word) - pronunciation_primary_key = self.get_next_primary_key(Pronunciation) - current_index = 0 - word_cache = {} - for w, w_type in special_words.items(): - if w_type is WordType.silence: - pron = self.optional_silence_phone - else: - pron = self.oov_phone - word_objs.append( - { - "id": word_primary_key, - "mapping_id": current_index, - "word": w, - "word_type": w_type, - "dictionary_id": dictionary.id, - } - ) - self._words_mappings[dictionary.id][w] = current_index - current_index += 1 - - pron_objs.append( - { - "id": pronunciation_primary_key, - "pronunciation": pron, - "probability": 1.0, - "disambiguation": None, - "silence_after_probability": None, - "silence_before_correction": None, - "non_silence_before_correction": None, - "word_id": word_primary_key, - } - ) - word_primary_key += 1 - pronunciation_primary_key += 1 - - query = ( - session.query( - Word.word, - Word.word_type, - Pronunciation.pronunciation, - ) - .join(Pronunciation.word) - .filter(~Word.word.in_(special_words.keys())) - .distinct() - .order_by(Word.word, Pronunciation.pronunciation) - ) - for word, word_type, pronunciation in query: - if word not in word_cache: - word_objs.append( - { - "id": word_primary_key, - "mapping_id": current_index, - "word": word, - "word_type": word_type, - "dictionary_id": dictionary.id, - } - ) - self._words_mappings[dictionary.id][word] = current_index - current_index += 1 - word_cache[word] = word_primary_key - word_primary_key += 1 - pron_objs.append( - { - "id": pronunciation_primary_key, - "pronunciation": pronunciation, - "probability": None, - "disambiguation": None, - "silence_after_probability": None, - "silence_before_correction": None, - "non_silence_before_correction": None, - "word_id": word_cache[word], - } - ) - pronunciation_primary_key += 1 - for s in ["#0", "", ""]: - word_objs.append( - { - "id": word_primary_key, - "word": s, - "dictionary_id": dictionary.id, - "mapping_id": current_index, - "word_type": WordType.disambiguation, - } - ) - self._words_mappings[dictionary.id][s] = current_index - word_primary_key += 1 - current_index += 1 - with self.session() as session: - with session.bind.begin() as conn: - if word_objs: - conn.execute(sqlalchemy.insert(Word.__table__), word_objs) - if pron_objs: - conn.execute(sqlalchemy.insert(Pronunciation.__table__), pron_objs) - session.commit() def create_nonnative_dictionary(self): with self.session() as session: @@ -907,7 +806,6 @@ def create_nonnative_dictionary(self): pron_objs = [] 
self.dictionary_lookup[dictionary.name] = dictionary.id - self._words_mappings[dictionary.id] = {} word_primary_key = self.get_next_primary_key(Word) pronunciation_primary_key = self.get_next_primary_key(Pronunciation) word_cache = {} @@ -919,7 +817,6 @@ def create_nonnative_dictionary(self): Pronunciation.pronunciation, ) .join(Word.pronunciations, isouter=True) - .filter(Word.dictionary_id == self._default_dictionary_id) .distinct() .order_by(Word.mapping_id) ) @@ -934,7 +831,6 @@ def create_nonnative_dictionary(self): "dictionary_id": dictionary.id, } ) - self._words_mappings[dictionary.id][word] = mapping_id word_cache[word] = word_primary_key word_primary_key += 1 if pronunciation is not None: @@ -1588,27 +1484,19 @@ def find_all_cutoffs(self) -> None: cutoff_identifier = re.sub( rf"[{initial_brackets}{final_brackets}]", "", self.cutoff_word ) - max_ids = collections.defaultdict(int) max_pron_id = session.query(sqlalchemy.func.max(Pronunciation.id)).scalar() max_word_id = session.query(sqlalchemy.func.max(Word.id)).scalar() - new_word_mapping = {} - new_pronunciation_mapping = [] - for d_id, max_id in ( - session.query(Dictionary.id, sqlalchemy.func.max(Word.mapping_id)) - .join(Word.dictionary) - .group_by(Dictionary.id) - ): - max_ids[d_id] = max_id for d_id in self.dictionary_lookup.values(): pronunciation_mapping = collections.defaultdict(set) - word_mapping = {} + new_word_mapping = {} + new_pronunciation_mapping = [] max_id = ( session.query(sqlalchemy.func.max(Word.mapping_id)) .join(Word.dictionary) .filter(Dictionary.id == d_id) ).first()[0] words = ( - session.query(Word.mapping_id, Word.word, Pronunciation.pronunciation) + session.query(Word.word, Pronunciation.pronunciation) .join(Pronunciation.word) .filter( Word.dictionary_id == d_id, @@ -1616,8 +1504,7 @@ def find_all_cutoffs(self) -> None: Word.word_type == WordType.speech, ) ) - for m_id, w, pron in words: - word_mapping[w] = m_id + for w, pron in words: pronunciation_mapping[w].add(pron) new_word = f"{self.cutoff_word[:-1]}-{w}{self.cutoff_word[-1]}" if new_word not in new_word_mapping: @@ -1638,11 +1525,13 @@ def find_all_cutoffs(self) -> None: "dictionary_id": d_id, "mapping_id": max_id, "word_type": WordType.cutoff, + "count": 0, + "included": False, } - word_mapping[new_word] = max_id p = pron.split() for pi in range(len(p)): new_p = " ".join(p[: pi + 1]) + if new_p in pronunciation_mapping[new_word]: continue pronunciation_mapping[new_word].add(new_p) @@ -1663,49 +1552,37 @@ def find_all_cutoffs(self) -> None: .filter( Speaker.dictionary_id == d_id, Utterance.normalized_text.regexp_match( - f"[{initial_brackets}]({cutoff_identifier}|hes)" + f"[{initial_brackets}]{cutoff_identifier}" ), ) ) - utterance_mapping = [] for u_id, normalized_text in utterances: text = normalized_text.split() - modified = False for i, word in enumerate(text): - m = re.match( - f"^[{initial_brackets}]({cutoff_identifier}|hes(itation)?)([-_](?P[^{final_brackets}]+))?[{final_brackets}]$", - word, - ) - if not m: + if not word.startswith(f"{self.cutoff_word[:-1]}-"): continue - next_word = m.group("word") - new_word = f"{self.cutoff_word[:-1]}-{next_word}{self.cutoff_word[-1]}" - if ( - next_word is None - or next_word not in word_mapping - or self.oov_phone in pronunciation_mapping[next_word] - or self.optional_silence_phone in pronunciation_mapping[next_word] - or new_word not in word_mapping - ): + if word not in new_word_mapping: continue - text[i] = new_word - modified = True - if modified: - utterance_mapping.append( - { - "id": u_id, 
- "normalized_text": " ".join(text), - } - ) - session.bulk_insert_mappings( - Word, new_word_mapping.values(), return_defaults=False, render_nulls=True - ) - session.bulk_insert_mappings( - Pronunciation, new_pronunciation_mapping, return_defaults=False, render_nulls=True - ) - bulk_update(session, Utterance, utterance_mapping) + new_word_mapping[word]["count"] += 1 + new_word_mapping[word]["included"] = True + session.bulk_insert_mappings( + Word, new_word_mapping.values(), return_defaults=False, render_nulls=True + ) + session.bulk_insert_mappings( + Pronunciation, + new_pronunciation_mapping, + return_defaults=False, + render_nulls=True, + ) + session.flush() session.query(Corpus).update({"cutoffs_found": True}) session.commit() + oov_count_threshold = getattr(self, "oov_count_threshold", 0) + if oov_count_threshold > 0: + session.query(Word).filter(Word.word_type == WordType.cutoff).filter( + Word.count <= oov_count_threshold + ).update({Word.included: False}) + session.commit() self._words_mappings = {} def write_lexicon_information(self, write_disambiguation: Optional[bool] = False) -> None: @@ -1783,33 +1660,67 @@ def build_lexicon_compiler( else: lexicon_compiler = acoustic_model.lexicon_compiler lexicon_compiler.disambiguation = disambiguation - query = ( - session.query(Word, Pronunciation) - .join(Pronunciation.word) - .filter(Word.dictionary_id == d.id) - .filter(Word.included == True) # noqa - .filter(Word.word_type != WordType.silence) - .order_by(Word.word) - ) + if d.name != "default": + query = ( + session.query( + Word.mapping_id, + Word.word, + Pronunciation.pronunciation, + Pronunciation.probability, + Pronunciation.silence_after_probability, + Pronunciation.silence_before_correction, + Pronunciation.non_silence_before_correction, + ) + .join(Pronunciation.word) + .filter(Word.dictionary_id == d.id) + .filter(Word.included == True) # noqa + .filter(Word.word_type != WordType.silence) + .order_by(Word.word) + ) + else: + query = ( + session.query( + Word.mapping_id, + Word.word, + Pronunciation.pronunciation, + sqlalchemy.func.avg(Pronunciation.probability), + sqlalchemy.func.avg(Pronunciation.silence_after_probability), + sqlalchemy.func.avg(Pronunciation.silence_before_correction), + sqlalchemy.func.avg(Pronunciation.non_silence_before_correction), + ) + .join(Pronunciation.word) + .filter(Word.included == True) # noqa + .filter(Word.word_type != WordType.silence) + .group_by(Word.mapping_id, Word.word, Pronunciation.pronunciation) + .order_by(Word.mapping_id) + ) lexicon_compiler.word_table = d.word_table - for w, p in query: - phones = p.pronunciation.split() + for ( + mapping_id, + word, + pronunciation, + probability, + silence_after_probability, + silence_before_correction, + non_silence_before_correction, + ) in query: + phones = pronunciation.split() if self.position_dependent_phones: if any(not lexicon_compiler.phone_table.member(x + "_S") for x in phones): continue else: if any(not lexicon_compiler.phone_table.member(x) for x in phones): continue - if not lexicon_compiler.word_table.member(w.word): - lexicon_compiler.word_table.add_symbol(w.word, w.mapping_id) + if not lexicon_compiler.word_table.member(word): + lexicon_compiler.word_table.add_symbol(word, mapping_id) lexicon_compiler.pronunciations.append( KalpyPronunciation( - w.word, - p.pronunciation, - p.probability, - p.silence_after_probability, - p.silence_before_correction, - p.non_silence_before_correction, + word, + pronunciation, + probability, + silence_after_probability, + 
silence_before_correction, + non_silence_before_correction, None, ) ) diff --git a/montreal_forced_aligner/exceptions.py b/montreal_forced_aligner/exceptions.py index f82dea66..0f5b4f6f 100644 --- a/montreal_forced_aligner/exceptions.py +++ b/montreal_forced_aligner/exceptions.py @@ -367,15 +367,13 @@ class DictionaryFileError(DictionaryError): Parameters ---------- - input_path: :class:`~pathlib.Path` - Path of the pronunciation dictionary + message: str + Error message """ - def __init__(self, input_path: Path): + def __init__(self, message: str): super().__init__("") - self.message_lines = [ - f"The specified path for the dictionary ({input_path}) is not a file." - ] + self.message_lines = [message] # Corpus Errors diff --git a/montreal_forced_aligner/g2p/generator.py b/montreal_forced_aligner/g2p/generator.py index 382d18ee..eba43f1b 100644 --- a/montreal_forced_aligner/g2p/generator.py +++ b/montreal_forced_aligner/g2p/generator.py @@ -92,6 +92,17 @@ def threshold_lattice_to_dfa( return lattice +def generate_scored_paths(lattice, output_token_type): + paths = lattice.paths(output_token_type=output_token_type) + output = [] + while not paths.done(): + score = float(paths.weight()) + ostring = paths.ostring() + output.append((ostring, score)) + paths.next() + return output + + def optimal_rewrites( string: pynini.FstLike, rule: pynini.Fst, @@ -111,7 +122,31 @@ def optimal_rewrites( """ lattice = rewrite.rewrite_lattice(string, rule, input_token_type) lattice = threshold_lattice_to_dfa(lattice, threshold, 4) - return rewrite.lattice_to_strings(lattice, output_token_type) + return generate_scored_paths(lattice, output_token_type) + + +def scored_top_rewrites( + string: pynini.FstLike, + rule: pynini.Fst, + nshortest: int, + input_token_type: Optional[pynini.TokenType] = None, + output_token_type: Optional[pynini.TokenType] = None, +) -> List[str]: + """Returns the top n rewrites. + + Args: + string: Input string or FST. + rule: Input rule WFST. + nshortest: The maximum number of rewrites to return. + input_token_type: Optional input token type, or symbol table. + output_token_type: Optional output token type, or symbol table. + + Returns: + A list of output strings. + """ + lattice = rewrite.rewrite_lattice(string, rule, input_token_type) + lattice = rewrite.lattice_to_nshortest(lattice, nshortest) + return generate_scored_paths(lattice, output_token_type) class Rewriter: @@ -124,7 +159,7 @@ class Rewriter: G2P FST model input_token_type: pynini.TokenType Grapheme symbol table or "utf8" - output_token_type: pynini.SymbolTable + phone_symbol_table: pynini.SymbolTable Phone symbol table num_pronunciations: int Number of pronunciations, default to 0. 
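# Editor's illustrative sketch (not part of the patch): with the change above,
# generate_scored_paths() / optimal_rewrites() yield (output string, path weight)
# pairs rather than bare strings, the weight being the tropical-semiring path
# cost from the pruned lattice (lower is better).  The values below are invented.
scored = [("IH0 G Z AE1 M P AH0 L", 4.19), ("EH0 G Z AE1 M P AH0 L", 5.73)]
best_pron, best_weight = min(scored, key=lambda pair: pair[1])  # -> the first pair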
If this is 0, thresholding is used @@ -135,7 +170,7 @@ class Rewriter: def __init__( self, fst: Fst, - input_token_type: TokenType, + grapheme_symbol_table: SymbolTable, phone_symbol_table: SymbolTable, num_pronunciations: int = 0, threshold: float = 1, @@ -143,12 +178,12 @@ def __init__( strict: bool = False, ): self.graphemes = graphemes - self.input_token_type = input_token_type + self.grapheme_symbol_table = grapheme_symbol_table self.phone_symbol_table = phone_symbol_table self.strict = strict if num_pronunciations > 0: self.rewrite = functools.partial( - rewrite.top_rewrites, + scored_top_rewrites, nshortest=num_pronunciations, rule=fst, input_token_type=None, @@ -168,7 +203,7 @@ def create_word_fst(self, word: str) -> pynini.Fst: if self.strict and any(x not in self.graphemes for x in word): return None word = "".join([x for x in word if x in self.graphemes]) - fst = pynini.accep(word, token_type=self.input_token_type) + fst = pynini.accep(word, token_type=self.grapheme_symbol_table) return fst def __call__(self, graphemes: str) -> List[str]: # pragma: no cover @@ -181,16 +216,22 @@ def __call__(self, graphemes: str) -> List[str]: # pragma: no cover if not w_fst: continue hypotheses.append(self.rewrite(w_fst)) - hypotheses = sorted(set(" ".join(x) for x in itertools.product(*hypotheses))) + hypotheses = sorted( + ( + (" ".join([y[0] for y in x]), sum([y[1] for y in x])) + for x in itertools.product(*hypotheses) + ), + key=lambda x: x[1], + ) else: fst = self.create_word_fst(graphemes) if not fst: return [] hypotheses = self.rewrite(fst) - return [x for x in hypotheses if x] + return [x for x in hypotheses if x[0]] -class PhonetisaurusRewriter: +class PhonetisaurusRewriter(Rewriter): """ Helper function for rewriting @@ -224,29 +265,12 @@ def __init__( graphemes: Set[str] = None, strict: bool = False, ): - self.fst = fst + super().__init__( + fst, grapheme_symbol_table, phone_symbol_table, num_pronunciations, threshold, strict + ) self.sequence_separator = sequence_separator - self.grapheme_symbol_table = grapheme_symbol_table - self.phone_symbol_table = phone_symbol_table self.grapheme_order = grapheme_order self.graphemes = graphemes - self.strict = strict - if num_pronunciations > 0: - self.rewrite = functools.partial( - rewrite.top_rewrites, - nshortest=num_pronunciations, - rule=fst, - input_token_type=None, - output_token_type=self.phone_symbol_table, - ) - else: - self.rewrite = functools.partial( - optimal_rewrites, - threshold=threshold, - rule=fst, - input_token_type=None, - output_token_type=self.phone_symbol_table, - ) def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]: if self.graphemes is not None: @@ -278,22 +302,8 @@ def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]: def __call__(self, graphemes: str) -> List[str]: # pragma: no cover """Call the rewrite function""" - if " " in graphemes: - words = graphemes.split() - hypotheses = [] - for w in words: - w_fst = self.create_word_fst(w) - if not w_fst: - continue - hypotheses.append(self.rewrite(w_fst)) - hypotheses = sorted(set(" ".join(x) for x in itertools.product(*hypotheses))) - else: - fst = self.create_word_fst(graphemes) - if not fst: - return [] - hypotheses = self.rewrite(fst) - hypotheses = [x.replace(self.sequence_separator, " ") for x in hypotheses if x] - return hypotheses + hypotheses = super().__call__(graphemes) + return [(x.replace(self.sequence_separator, " "), score) for x, score in hypotheses if x] class RewriterWorker(mp.Process): @@ -318,20 +328,24 @@ def 
__init__( return_queue: mp.Queue, rewriter: Rewriter, stopped: mp.Event, + finished_adding: mp.Event, ): super().__init__() self.job_queue = job_queue self.return_queue = return_queue self.rewriter = rewriter self.stopped = stopped + self.finished_adding = finished_adding self.finished = mp.Event() def run(self) -> None: """Run the rewriting function""" while True: try: - word = self.job_queue.get(timeout=1) + word = self.job_queue.get(timeout=0.01) except queue.Empty: + if not self.finished_adding.is_set(): + continue break if self.stopped.is_set(): continue @@ -356,8 +370,8 @@ class G2PArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run tree_path: :class:`~pathlib.Path` @@ -431,19 +445,19 @@ class OrthographyGenerator(G2PTopLevelMixin): For top level G2P generation parameters """ - def generate_pronunciations(self) -> Dict[str, List[str]]: + def generate_pronunciations( + self, + ) -> typing.Generator[typing.Tuple[str, List[typing.Tuple[str, float]]]]: """ Generate pronunciations for the word set - Returns + Yields ------- - dict[str, Word] - Mapping of words to their "pronunciation" + str, list[tuple[str, float]]] + Tuple of word with their scored pronunciations """ - pronunciations = {} for word in self.words_to_g2p: - pronunciations[word] = [" ".join(word)] - return pronunciations + yield word, [(" ".join(word), 0.0)] class PyniniGenerator(G2PTopLevelMixin): @@ -538,14 +552,16 @@ def setup(self): graphemes=self.g2p_model.meta["graphemes"], ) - def generate_pronunciations(self) -> Dict[str, List[str]]: + def generate_pronunciations( + self, + ) -> typing.Generator[typing.Tuple[str, List[typing.Tuple[str, float]]]]: """ Generate pronunciations - Returns + Yields ------- - dict[str, list[str]] - Mappings of keys to their generated pronunciations + str, list[tuple[str, float]]] + Tuple of word with their scored pronunciations """ num_words = len(self.words_to_g2p) @@ -554,8 +570,9 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: if self.rewriter is None: self.setup() logger.info("Generating pronunciations...") - to_return = {} skipped_words = 0 + min_score = None + max_score = None if not config.USE_MP or num_words < 30 or config.NUM_JOBS == 1: with tqdm(total=num_words, disable=config.QUIET) as pbar: for word in self.words_to_g2p: @@ -572,14 +589,34 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: prons = self.rewriter(w) except rewrite.Error: continue - to_return[word] = prons + scores = [x[1] for x in prons] + if min_score is not None: + scores.append(min_score) + if max_score is not None: + scores.append(max_score) + min_score = min(scores) + max_score = min(scores) + yield w, prons logger.debug( f"Skipping {skipped_words} words for containing the following graphemes: " f"{comma_join(sorted(missing_graphemes))}" ) else: stopped = mp.Event() + finished_adding = mp.Event() job_queue = mp.Queue() + return_queue = mp.Queue() + procs = [] + for _ in range(config.NUM_JOBS): + p = RewriterWorker( + job_queue, + return_queue, + self.rewriter, + stopped, + finished_adding, + ) + procs.append(p) + p.start() for word in self.words_to_g2p: w, m = clean_up_word(word, self.g2p_model.meta["graphemes"]) missing_graphemes = missing_graphemes | m @@ -594,18 +631,8 @@ def generate_pronunciations(self) -> 
Dict[str, List[str]]: f"Skipping {skipped_words} words for containing the following graphemes: " f"{comma_join(sorted(missing_graphemes))}" ) + finished_adding.set() error_dict = {} - return_queue = mp.Queue() - procs = [] - for _ in range(config.NUM_JOBS): - p = RewriterWorker( - job_queue, - return_queue, - self.rewriter, - stopped, - ) - procs.append(p) - p.start() num_words -= skipped_words with tqdm(total=num_words, disable=config.QUIET) as pbar: while True: @@ -624,14 +651,22 @@ def generate_pronunciations(self) -> Dict[str, List[str]]: if isinstance(result, Exception): error_dict[word] = result continue - to_return[word] = result + scores = [x[1] for x in result] + if min_score is not None: + scores.append(min_score) + if max_score is not None: + scores.append(max_score) + min_score = min(scores) + max_score = min(scores) + yield word, result for p in procs: p.join() if error_dict: raise PyniniGenerationError(error_dict) + logger.debug(f"Minimum score: {min_score}") + logger.debug(f"Maximum score: {max_score}") logger.debug(f"Processed {num_words} in {time.time() - begin:.3f} seconds") - return to_return class PyniniConsoleGenerator(PyniniGenerator): @@ -712,7 +747,7 @@ def compute_validation_errors( total_length = 0 # Since the edit distance algorithm is quadratic, let's do this with # multiprocessing. - logger.debug(f"Processing results for {len(hypothesis_values)} hypotheses") + logger.debug("Processing results for hypotheses") to_comp = [] indices = [] hyp_pron_count = 0 @@ -816,8 +851,8 @@ def evaluate_g2p_model(self, gold_pronunciations: Dict[str, Set[str]]) -> None: gold_pronunciations: dict[str, set[str]] Gold pronunciations """ - output = self.generate_pronunciations() - self.compute_validation_errors(gold_pronunciations, output) + hypotheses = {w: [x[0] for x in prons] for w, prons in self.generate_pronunciations()} + self.compute_validation_errors(gold_pronunciations, hypotheses) class PyniniWordListGenerator(PyniniValidator, DatabaseMixin): @@ -937,7 +972,7 @@ def export_file_pronunciations(self, output_file_path: Path): self.setup() logger.info("Generating pronunciations...") update_mapping = [] - for utt_id, pronunciation in run_kaldi_function( + for utt_id, (pronunciation, _) in run_kaldi_function( G2PFunction, self.g2p_arguments(), total_count=self.num_utterances ): update_mapping.append({"id": utt_id, "transcription_text": pronunciation}) @@ -1004,11 +1039,18 @@ def words_to_g2p(self) -> List[str]: word_list = [x for x in word_list if not self.check_bracketed(x)] return word_list - def export_pronunciations(self, output_file_path: typing.Union[str, Path]) -> None: + def export_pronunciations( + self, + output_file_path: typing.Union[str, Path], + export_scores: bool = False, + ensure_sorted: bool = False, + ) -> None: if self.per_utterance: self.export_file_pronunciations(output_file_path) else: - super().export_pronunciations(output_file_path) + super().export_pronunciations( + output_file_path, export_scores=export_scores, ensure_sorted=ensure_sorted + ) class PyniniDictionaryCorpusGenerator( diff --git a/montreal_forced_aligner/g2p/mixins.py b/montreal_forced_aligner/g2p/mixins.py index 3d19ccb9..288b61f5 100644 --- a/montreal_forced_aligner/g2p/mixins.py +++ b/montreal_forced_aligner/g2p/mixins.py @@ -70,18 +70,25 @@ class G2PTopLevelMixin(MfaWorker, DictionaryMixin, G2PMixin): def __init__(self, **kwargs): super().__init__(**kwargs) - def generate_pronunciations(self) -> Dict[str, List[str]]: + def generate_pronunciations( + self, + ) -> 
typing.List[typing.Tuple[str, List[typing.Tuple[str, float]]]]: """ Generate pronunciations Returns ------- - dict[str, list[str]] - Mappings of keys to their generated pronunciations + str, list[tuple[str, float]]] + Tuple of word with their scored pronunciations """ raise NotImplementedError - def export_pronunciations(self, output_file_path: typing.Union[str, Path]) -> None: + def export_pronunciations( + self, + output_file_path: typing.Union[str, Path], + export_scores: bool = False, + ensure_sorted: bool = False, + ) -> None: """ Output pronunciations to text file @@ -89,16 +96,26 @@ def export_pronunciations(self, output_file_path: typing.Union[str, Path]) -> No ---------- output_file_path: :class:`~pathlib.Path` Path to save + export_scores: bool + Flag for appending a column for the score of the pronunciation + ensure_sorted: bool + Flag for ensuring that output file is sorted alphabetically """ if isinstance(output_file_path, str): output_file_path = Path(output_file_path) output_file_path.parent.mkdir(parents=True, exist_ok=True) results = self.generate_pronunciations() + if ensure_sorted: + results = sorted(results, key=lambda x: x[0]) with mfa_open(output_file_path, "w") as f: - for (orthography, pronunciations) in results.items(): + for orthography, pronunciations in results: if not pronunciations: continue - for p in pronunciations: + for p, score in pronunciations: if not p: continue - f.write(f"{orthography}\t{p}\n") + if export_scores: + f.write(f"{orthography}\t{p}\t{score}\n") + else: + f.write(f"{orthography}\t{p}\n") + f.flush() diff --git a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py index 3fb42847..c0c466ed 100644 --- a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py +++ b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py @@ -1663,15 +1663,17 @@ def evaluate_g2p_model(self) -> None: num_pronunciations=self.num_pronunciations, ) output = gen.generate_pronunciations() + hypotheses = {} with mfa_open(temp_dir.joinpath("validation_output.txt"), "w") as f: - for orthography, pronunciations in output.items(): + for orthography, pronunciations in output: + hypotheses[orthography] = [x[0] for x in pronunciations] if not pronunciations: continue - for p in pronunciations: + for p, score in pronunciations: if not p: continue - f.write(f"{orthography}\t{p}\n") - gen.compute_validation_errors(validation_set, output) + f.write(f"{orthography}\t{p}\t{score}\n") + gen.compute_validation_errors(validation_set, hypotheses) def initialize_training(self) -> None: """Initialize training G2P model""" diff --git a/montreal_forced_aligner/language_modeling/multiprocessing.py b/montreal_forced_aligner/language_modeling/multiprocessing.py index 93f08a34..f2103fb2 100644 --- a/montreal_forced_aligner/language_modeling/multiprocessing.py +++ b/montreal_forced_aligner/language_modeling/multiprocessing.py @@ -36,8 +36,8 @@ class TrainSpeakerLmArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run model_path: :class:`~pathlib.Path` @@ -70,16 +70,18 @@ class TrainLmArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped 
session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory symbols_path: :class:`~pathlib.Path` Words symbol table paths - oov_word: str - OOV word order: int Ngram order of the language models + oov_word: str + OOV word """ working_directory: Path @@ -294,7 +296,6 @@ def _run(self) -> None: .distinct() ) for (speaker_id,) in speakers: - print(speaker_id) hclg_path = d.temp_directory.joinpath(f"{speaker_id}.fst") if os.path.exists(hclg_path): continue diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index b8c5cba3..01014b3c 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -1652,13 +1652,20 @@ class ModelManager: ---------- token: str, optional GitHub authentication token to use to increase release limits + hf_token: str, optional + HuggingFace authentication token to use to increase release limits ignore_cache: bool Flag to ignore previously downloaded files """ base_url = "https://api.github.com/repos/MontrealCorpusTools/mfa-models/releases" - def __init__(self, token: typing.Optional[str] = None, ignore_cache: bool = False): + def __init__( + self, + token: typing.Optional[str] = None, + hf_token: typing.Optional[str] = None, + ignore_cache: bool = False, + ): from montreal_forced_aligner.config import get_temporary_directory pretrained_dir = get_temporary_directory().joinpath("pretrained_models") @@ -1671,6 +1678,10 @@ def __init__(self, token: typing.Optional[str] = None, ignore_cache: bool = Fals environment_token = os.environ.get("GITHUB_TOKEN", None) if self.token is None: self.token = environment_token + self.hf_token = hf_token + environment_token = os.environ.get("HF_TOKEN", None) + if self.hf_token is None: + self.hf_token = environment_token self.synced_remote = False self.ignore_cache = ignore_cache self._cache_info = {} diff --git a/montreal_forced_aligner/online/transcription.py b/montreal_forced_aligner/online/transcription.py index dd08550d..633f8580 100644 --- a/montreal_forced_aligner/online/transcription.py +++ b/montreal_forced_aligner/online/transcription.py @@ -1,6 +1,9 @@ """Classes for calculating alignments online""" from __future__ import annotations +import typing + +import torch from _kalpy.fstext import ConstFst from _kalpy.matrix import DoubleMatrix, FloatMatrix from kalpy.feat.cmvn import CmvnComputer @@ -9,8 +12,16 @@ from kalpy.gmm.decode import GmmDecoder from kalpy.utterance import Utterance as KalpyUtterance +from montreal_forced_aligner.data import Language from montreal_forced_aligner.exceptions import AlignerError from montreal_forced_aligner.models import AcousticModel +from montreal_forced_aligner.tokenization.simple import SimpleTokenizer +from montreal_forced_aligner.transcription.multiprocessing import ( + EncoderASR, + WhisperASR, + WhisperModel, + get_suppressed_tokens, +) def transcribe_utterance_online( @@ -82,3 +93,44 @@ def transcribe_utterance_online( ctm.likelihood = alignment.likelihood ctm.update_utterance_boundaries(utterance.segment.begin, utterance.segment.end) return ctm + + +def transcribe_utterance_online_whisper( + model: WhisperModel, + utterance: KalpyUtterance, + beam: int = 5, + language: Language = Language.unknown, + tokenizer: SimpleTokenizer = None, +) -> str: + segment = utterance.segment + waveform = segment.load_audio() + suppressed = get_suppressed_tokens(model) + segments, info = model.transcribe( + waveform, + 
language=language.iso_code, + beam_size=beam, + suppress_tokens=suppressed, + temperature=0.0, + condition_on_previous_text=False, + ) + text = " ".join([x.text for x in segments]) + text = text.replace(" ", " ") + if tokenizer is not None: + text = tokenizer(text)[0] + return text.strip() + + +def transcribe_utterance_online_speechbrain( + model: typing.Union[WhisperASR, EncoderASR], + utterance: KalpyUtterance, + tokenizer: SimpleTokenizer = None, +) -> str: + segment = utterance.segment + waveform = segment.load_audio() + waveform = model.audio_normalizer(waveform, 16000).unsqueeze(0) + lens = torch.tensor([1.0]) + predicted_words, predicted_tokens = model.transcribe_batch(waveform, lens) + text = predicted_words[0] + if tokenizer is not None: + text = tokenizer(text)[0] + return text diff --git a/montreal_forced_aligner/tokenization/english.py b/montreal_forced_aligner/tokenization/english.py index c5880fd7..280b5124 100644 --- a/montreal_forced_aligner/tokenization/english.py +++ b/montreal_forced_aligner/tokenization/english.py @@ -373,11 +373,6 @@ def __call__(self, doc: Doc): continue except KeyError: continue - lemma = w.lemma_ - norm = w.norm_ - morph = str(w.morph) - pos = w.pos_ - print(w.text, lemma, norm, morph, pos, w.is_oov) span = None if "Prog" in w.morph.get("Aspect") and w.text.endswith("ing"): span = self.handle_ing(w) @@ -423,7 +418,6 @@ def __call__(self, doc: Doc): break else: break - print(span) if span is not None: with doc.retokenize() as retokenizer: if len(span) == 4: @@ -460,65 +454,65 @@ def __call__(self, doc): def en_spacy(ignore_case: bool = True): name = "en_core_web_sm" try: - nlp = spacy.load(name) + en_nlp = spacy.load(name) except OSError: subprocess.call(["python", "-m", "spacy", "download", name], env=os.environ) - nlp = spacy.load(name) + en_nlp = spacy.load(name) @spacy.Language.factory("en_re_tokenize") - def en_re_tokenize(_nlp, name): - return EnglishReTokenize(_nlp.vocab) + def en_re_tokenize(nlp, name): + return EnglishReTokenize(nlp.vocab) @spacy.Language.factory("en_split_suffixes") - def en_split_suffixes(_nlp, name): - return EnglishSplitSuffixes(_nlp.vocab) + def en_split_suffixes(nlp, name): + return EnglishSplitSuffixes(nlp.vocab) @spacy.Language.factory("en_split_prefixes") - def en_split_prefixes(_nlp, name): - return EnglishSplitPrefixes(_nlp.vocab) + def en_split_prefixes(nlp, name): + return EnglishSplitPrefixes(nlp.vocab) @spacy.Language.factory("en_bracketed_re_tokenize") - def bracketed_re_tokenize(_nlp, name): - return BracketedReTokenize(_nlp.vocab) + def en_bracketed_re_tokenize(nlp, name): + return BracketedReTokenize(nlp.vocab) initial_brackets = r"\(\[\{<" final_brackets = r"\)\]\}>" - nlp.tokenizer.token_match = re.compile( + en_nlp.tokenizer.token_match = re.compile( rf"[{initial_brackets}][-\w_']+[?!,][{final_brackets}]" ).match - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "wanna", [{ORTH: "wan", NORM: "want"}, {ORTH: "na", NORM: "to"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "dunno", [{ORTH: "dun", NORM: "don't"}, {ORTH: "no", NORM: "know"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "woulda", [{ORTH: "would", NORM: "would"}, {ORTH: "a", NORM: "have"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "sorta", [{ORTH: "sort", NORM: "sort"}, {ORTH: "a", NORM: "of"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "kinda", [{ORTH: "kind", NORM: "kind"}, {ORTH: "a", NORM: "of"}] ) - 
nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "coulda", [{ORTH: "could", NORM: "could"}, {ORTH: "a", NORM: "have"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "shoulda", [{ORTH: "should", NORM: "should"}, {ORTH: "a", NORM: "have"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "finna", [{ORTH: "fin", NORM: "fixing"}, {ORTH: "na", NORM: "to"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "yknow", [{ORTH: "y", NORM: "you"}, {ORTH: "know", NORM: "know"}] ) - nlp.tokenizer.add_special_case( + en_nlp.tokenizer.add_special_case( "y'know", [{ORTH: "y'", NORM: "you"}, {ORTH: "know", NORM: "know"}] ) - nlp.add_pipe("en_re_tokenize", before="tagger") - nlp.add_pipe("bracketed_re_tokenize", before="tagger") - nlp.add_pipe("en_split_prefixes") - nlp.add_pipe("en_split_suffixes") - return nlp + en_nlp.add_pipe("en_re_tokenize", before="tagger") + en_nlp.add_pipe("en_bracketed_re_tokenize", before="tagger") + en_nlp.add_pipe("en_split_prefixes") + en_nlp.add_pipe("en_split_suffixes") + return en_nlp diff --git a/montreal_forced_aligner/tokenization/simple.py b/montreal_forced_aligner/tokenization/simple.py index 95a880d1..cbb0bca5 100644 --- a/montreal_forced_aligner/tokenization/simple.py +++ b/montreal_forced_aligner/tokenization/simple.py @@ -45,6 +45,7 @@ def __init__( word_break_regex: typing.Optional[re.Pattern], bracket_regex: typing.Optional[re.Pattern], bracket_sanitize_regex: typing.Optional[re.Pattern], + cutoff_regex: typing.Optional[re.Pattern], ignore_case: bool = True, ): self.word_table = word_table @@ -53,6 +54,7 @@ def __init__( self.clitic_quote_regex = clitic_quote_regex self.punctuation_regex = punctuation_regex self.word_break_regex = word_break_regex + self.cutoff_regex = cutoff_regex self.bracket_regex = bracket_regex self.bracket_sanitize_regex = bracket_sanitize_regex @@ -78,6 +80,8 @@ def __call__(self, text) -> typing.Generator[str]: if self.bracket_regex: for word_object in self.bracket_regex.finditer(text): word = word_object.group(0) + if self.cutoff_regex is not None and self.cutoff_regex.match(word): + continue if self.word_table and self.word_table.member(word): continue new_word = self.bracket_sanitize_regex.sub("_", word) @@ -175,9 +179,9 @@ def to_str(self, normalized_text: str) -> str: return self.oov_word if self.word_table and self.word_table.member(normalized_text): return normalized_text + if self.cutoff_regex is not None and self.cutoff_regex.match(normalized_text): + return normalized_text for word, regex in self.non_speech_regexes.items(): - if self.cutoff_regex.match(normalized_text): - return normalized_text if regex.match(normalized_text): return word return normalized_text @@ -300,8 +304,8 @@ def __call__( """ if self.word_table and self.word_table.member(item): return [item] - if self.cutoff_regex.match(item): - return item + if self.cutoff_regex is not None and self.cutoff_regex.match(item): + return [item] for regex in self.non_speech_regexes.values(): if regex.match(item): return [item] @@ -377,6 +381,7 @@ def __init__( self.word_break_regex, self.bracket_regex, self.bracket_sanitize_regex, + self.cutoff_regex, self.ignore_case, ) self.split_function = SplitWordsFunction( diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index 9bf09ac3..34d52427 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ 
b/montreal_forced_aligner/transcription/multiprocessing.py @@ -5,7 +5,11 @@ """ from __future__ import annotations +import logging import os +import queue +import re +import threading import typing from pathlib import Path from typing import TYPE_CHECKING, Dict @@ -14,7 +18,7 @@ from _kalpy.lat import CompactLatticeWriter from _kalpy.lm import ConstArpaLm from _kalpy.util import BaseFloatMatrixWriter, Int32VectorWriter, ReadKaldiObject -from kalpy.data import KaldiMapping, MatrixArchive +from kalpy.data import KaldiMapping, MatrixArchive, Segment from kalpy.decoder.decode_graph import DecodeGraphCompiler from kalpy.feat.data import FeatureArchive from kalpy.feat.fmllr import FmllrComputer @@ -25,9 +29,13 @@ from kalpy.utils import generate_write_specifier from sqlalchemy.orm import joinedload, subqueryload +from montreal_forced_aligner import config from montreal_forced_aligner.abc import KaldiFunction, MetaDict -from montreal_forced_aligner.data import MfaArguments, PhoneType -from montreal_forced_aligner.db import Job, Phone, Utterance +from montreal_forced_aligner.data import Language, MfaArguments, PhoneType +from montreal_forced_aligner.db import File, Job, Phone, SoundFile, Utterance +from montreal_forced_aligner.diarization.multiprocessing import UtteranceFileLoader +from montreal_forced_aligner.helper import mfa_open +from montreal_forced_aligner.tokenization.simple import SimpleTokenizer from montreal_forced_aligner.utils import thread_logger if TYPE_CHECKING: @@ -35,6 +43,43 @@ else: from dataclassy import dataclass +try: + from faster_whisper import WhisperModel + + FOUND_WHISPER = True +except ImportError: + WhisperModel = None + FOUND_WHISPER = False + +try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend") + torch_logger.setLevel(logging.ERROR) + torch_logger = logging.getLogger("speechbrain.utils.train_logger") + torch_logger.setLevel(logging.ERROR) + transformers_logger = logging.getLogger("transformers.modeling_utils") + transformers_logger.setLevel(logging.ERROR) + transformers_logger = logging.getLogger( + "speechbrain.lobes.models.huggingface_transformers.huggingface" + ) + transformers_logger.setLevel(logging.ERROR) + transformers_logger = logging.getLogger("kenlm") + transformers_logger.setLevel(logging.ERROR) + import torch + + try: + from speechbrain.pretrained import EncoderASR, WhisperASR + except ImportError: # speechbrain 1.0 + from speechbrain.inference.ASR import EncoderASR, WhisperASR + FOUND_SPEECHBRAIN = True +except (ImportError, OSError): + FOUND_SPEECHBRAIN = False + WhisperASR = None + EncoderASR = None + __all__ = [ "FmllrRescoreFunction", @@ -44,6 +89,17 @@ "DecodeFunction", "LmRescoreFunction", "CreateHclgFunction", + "FOUND_SPEECHBRAIN", + "FOUND_WHISPER", + "WhisperModel", + "WhisperASR", + "EncoderASR", + "SpeechbrainAsrArguments", + "SpeechbrainAsrCudaArguments", + "WhisperArguments", + "WhisperCudaArguments", + "SpeechbrainAsrFunction", + "WhisperAsrFunction", ] @@ -56,16 +112,12 @@ class CreateHclgArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run working_directory: :class:`~pathlib.Path` Current working directory - words_path: :class:`~pathlib.Path` - Path to 
words symbol table - carpa_path: :class:`~pathlib.Path` - Path to .carpa file small_arpa_path: :class:`~pathlib.Path` Path to small ARPA file medium_arpa_path: :class:`~pathlib.Path` @@ -74,14 +126,8 @@ class CreateHclgArguments(MfaArguments): Path to big ARPA file model_path: :class:`~pathlib.Path` Acoustic model path - disambig_L_path: :class:`~pathlib.Path` - Path to disambiguated lexicon file - disambig_int_path: :class:`~pathlib.Path` - Path to disambiguation symbol integer file hclg_options: dict[str, Any] HCLG options - words_mapping: dict[str, int] - Words mapping """ lexicon_compiler: LexiconCompiler @@ -94,6 +140,101 @@ class CreateHclgArguments(MfaArguments): hclg_options: MetaDict +@dataclass +class SpeechbrainAsrArguments(MfaArguments): + """ + Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + + working_directory: Path + architecture: str + language: Language + tokenizer: typing.Optional[SimpleTokenizer] + + +@dataclass +class SpeechbrainAsrCudaArguments(MfaArguments): + """ + Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + + working_directory: Path + model: typing.Union[EncoderASR, WhisperASR] + tokenizer: typing.Optional[SimpleTokenizer] + + +@dataclass +class WhisperArguments(MfaArguments): + """ + Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + + working_directory: Path + model_size: str + language: Language + decode_options: MetaDict + tokenizer: typing.Optional[SimpleTokenizer] + cuda: bool + + +@dataclass +class WhisperCudaArguments(MfaArguments): + """ + Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + + working_directory: Path + model: WhisperModel + language: Language + decode_options: MetaDict + tokenizer: typing.Optional[SimpleTokenizer] + cuda: bool + + @dataclass class DecodeArguments(MfaArguments): """ @@ -103,22 +244,16 @@ class DecodeArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: 
:class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run - dictionaries: list[int] - List of dictionary ids - feature_strings: dict[int, str] - Mapping of dictionaries to feature generation strings - decode_options: dict[str, Any] - Decoding options + working_directory: :class:`~pathlib.Path` + Working directory model_path: :class:`~pathlib.Path` Path to model file - lat_paths: dict[int, Path] - Per dictionary lattice paths - word_symbol_paths: dict[int, Path] - Per dictionary word symbol table paths + decode_options: dict[str, Any] + Decoding options hclg_paths: dict[int, Path] Per dictionary HCLG.fst paths """ @@ -138,10 +273,12 @@ class DecodePhoneArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory dictionaries: list[int] List of dictionary ids feature_strings: dict[int, str] @@ -173,10 +310,12 @@ class LmRescoreArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory dictionaries: list[int] List of dictionary ids lm_rescore_options: dict[str, Any] @@ -206,10 +345,12 @@ class CarpaLmRescoreArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory dictionaries: list[int] List of dictionary ids lat_paths: dict[int, Path] @@ -237,10 +378,12 @@ class InitialFmllrArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Working directory dictionaries: list[int] List of dictionary ids feature_strings: dict[int, str] @@ -272,8 +415,8 @@ class FinalFmllrArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run working_directory: :class:`~pathlib.Path` @@ -299,22 +442,15 @@ class FmllrRescoreArguments(MfaArguments): ---------- job_name: int Integer ID of the job - db_string: str - String for database connections + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections log_path: :class:`~pathlib.Path` Path to save logging information during the run - 
dictionaries: list[int] - List of dictionary ids - feature_strings: dict[int, str] - Mapping of dictionaries to feature generation strings - model_path: :class:`~pathlib.Path` + working_directory: :class:`~pathlib.Path` + Working directory Path to model file - fmllr_options: dict[str, Any] - fMLLR options - tmp_lat_paths: dict[int, Path] - Per dictionary temporary lattice paths - final_lat_paths: dict[int, Path] - Per dictionary lattice paths + rescore_options: dict[str, Any] + Rescoring options """ working_directory: Path @@ -454,6 +590,241 @@ def _run(self) -> None: ) +class SpeechbrainAsrFunction(KaldiFunction): + """ + Multiprocessing function for performing decoding + + See Also + -------- + :meth:`.TranscriberMixin.transcribe_utterances` + Main function that calls this function in parallel + :meth:`.TranscriberMixin.decode_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + + Parameters + ---------- + args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments` + Arguments for the function + """ + + def __init__(self, args: typing.Union[SpeechbrainAsrArguments, SpeechbrainAsrCudaArguments]): + super().__init__(args) + self.working_directory = args.working_directory + self.cuda = isinstance(args, SpeechbrainAsrCudaArguments) + self.model = None + self.tokenizer = args.tokenizer + if self.cuda: + self.model = args.model + else: + self.model = ( + f"speechbrain/asr-{args.architecture}-commonvoice-14-{args.language.iso_code}" + ) + + def _run(self) -> None: + """Run the function""" + run_opts = None + if self.cuda: + run_opts = {"device": "cuda"} + model = self.model + if isinstance(model, str): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if "wav2vec2" in model: + # Download models if needed + model = EncoderASR.from_hparams( + source=model, + savedir=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "EncoderASR", + model, + ), + run_opts=run_opts, + ) + else: + # Download models if needed + model = WhisperASR.from_hparams( + source=model, + savedir=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "WhisperASR", + model, + ), + run_opts=run_opts, + ) + + return_q = queue.Queue(2) + finished_adding = threading.Event() + stopped = threading.Event() + loader = UtteranceFileLoader( + self.job_name, + self.session, + return_q, + stopped, + finished_adding, + model=model, + for_xvector=False, + ) + loader.start() + exception = None + current_index = 0 + while True: + try: + batch = return_q.get(timeout=1) + except queue.Empty: + if finished_adding.is_set(): + break + continue + if stopped.is_set(): + continue + if isinstance(batch, Exception): + exception = batch + stopped.set() + continue + + audio, lens = batch.signal + predicted_words, predicted_tokens = model.transcribe_batch(audio, lens) + for i, u_id in enumerate(batch.utterance_id): + text = predicted_words[i] + if self.tokenizer is not None: + text = self.tokenizer(text)[0] + self.callback((int(u_id), text)) + del predicted_words + del predicted_tokens + del audio + del lens + current_index += 1 + if current_index > 10: + torch.cuda.empty_cache() + current_index = 0 + + loader.join() + if exception: + raise exception + + +def get_suppressed_tokens(model: WhisperModel) -> typing.List[int]: + suppressed = [] + i = 32 + roman_numeral_pattern = re.compile(r"x+(vi+|i{2,}|i+v|x+)", flags=re.IGNORECASE) + while True: + token = model.hf_tokenizer.id_to_token(i) + if token is None: + break + if not 
token.startswith("<|"): + if ( + not token.isalpha() + or re.search(r"\d", token) + or roman_numeral_pattern.search(token) + or re.search(r"[IXV]{2,}", token) + or re.search(r"i{2,}$", token) + or re.search(r"^(Ġ)?x{2,}", token) + or re.search(r"^(Ġ)?vi{2,}", token) + or re.match(r"^(Ġ)?[XV]$", token) + ): + suppressed.append(i) + i += 1 + return suppressed + + +class WhisperAsrFunction(KaldiFunction): + """ + Multiprocessing function for performing decoding + + See Also + -------- + :meth:`.TranscriberMixin.transcribe_utterances` + Main function that calls this function in parallel + :meth:`.TranscriberMixin.decode_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + + Parameters + ---------- + args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments` + Arguments for the function + """ + + def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]): + super().__init__(args) + self.working_directory = args.working_directory + self.cuda = args.cuda + self.model = None + self.language = args.language + self.decode_options = args.decode_options + if isinstance(args, WhisperCudaArguments): + self.model = args.model + else: + self.model = args.model_size + self.tokenizer = args.tokenizer + + def _run(self) -> None: + """Run the function""" + model = self.model + if isinstance(model, str): + if self.cuda: + run_opts = {"device": "cuda", "compute_type": "int8"} + else: + run_opts = {"device": "cpu"} + model = WhisperModel( + model, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=True, + **run_opts, + ) + transcribe_opts = {"language": None, "beam_size": self.decode_options["beam_size"]} + if self.language is not Language.unknown: + transcribe_opts["language"] = self.language.iso_code + suppressed = get_suppressed_tokens(model) + current_index = 0 + with self.session() as session, mfa_open(self.log_path, "w") as log_file: + log_file.write(f"Suppressed: {len(suppressed)}\n") + utterances = ( + session.query( + Utterance.id, + SoundFile.sound_file_path, + Utterance.begin, + Utterance.end, + Utterance.channel, + ) + .join(Utterance.file) + .join(File.sound_file) + .filter(Utterance.job_id == self.job_name) + ) + for u in utterances: + segment = Segment(u[1], u[2], u[3], u[4]) + waveform = segment.load_audio() + log_file.write(f"{u[0]}: {waveform.shape}\n") + segments, info = model.transcribe( + waveform, + condition_on_previous_text=False, + suppress_tokens=suppressed, + temperature=0.0, + **transcribe_opts, + ) + text = " ".join([x.text for x in segments]) + del waveform + del segments + del info + if self.tokenizer is not None: + text = self.tokenizer(text)[0] + self.callback((u[0], text)) + log_file.write(f"{u[0]}: {text}\n") + log_file.flush() + current_index += 1 + if current_index > 50: + torch.cuda.empty_cache() + current_index = 0 + + class LmRescoreFunction(KaldiFunction): """ Multiprocessing function rescore lattices by replacing the small G.fst with the medium G.fst diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index 45a62907..ca14fc62 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -11,13 +11,15 @@ import os import shutil import subprocess +import sys import threading import time import typing +import warnings from multiprocessing.pool import ThreadPool from 
pathlib import Path from queue import Empty, Queue -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional import pywrapfst from _kalpy.fstext import VectorFst @@ -28,10 +30,13 @@ from tqdm.rich import tqdm from montreal_forced_aligner import config -from montreal_forced_aligner.abc import TopLevelMfaWorker +from montreal_forced_aligner.abc import FileExporterMixin, TopLevelMfaWorker from montreal_forced_aligner.alignment.base import CorpusAligner +from montreal_forced_aligner.corpus.acoustic_corpus import AcousticCorpusMixin from montreal_forced_aligner.data import ( + ISO_LANGUAGE_MAPPING, ArpaNgramModel, + Language, TextFileType, TextgridFormats, WorkflowType, @@ -46,7 +51,8 @@ Utterance, bulk_update, ) -from montreal_forced_aligner.exceptions import KaldiProcessingError +from montreal_forced_aligner.dictionary.mixins import DictionaryMixin +from montreal_forced_aligner.exceptions import KaldiProcessingError, ModelError from montreal_forced_aligner.helper import ( load_configuration, mfa_open, @@ -62,6 +68,8 @@ from montreal_forced_aligner.models import AcousticModel, LanguageModel from montreal_forced_aligner.textgrid import construct_output_path from montreal_forced_aligner.transcription.multiprocessing import ( + FOUND_SPEECHBRAIN, + FOUND_WHISPER, CarpaLmRescoreArguments, CarpaLmRescoreFunction, CreateHclgArguments, @@ -70,6 +78,7 @@ DecodeFunction, DecodePhoneArguments, DecodePhoneFunction, + EncoderASR, FinalFmllrArguments, FinalFmllrFunction, FmllrRescoreArguments, @@ -80,6 +89,14 @@ LmRescoreFunction, PerSpeakerDecodeArguments, PerSpeakerDecodeFunction, + SpeechbrainAsrArguments, + SpeechbrainAsrCudaArguments, + SpeechbrainAsrFunction, + WhisperArguments, + WhisperASR, + WhisperAsrFunction, + WhisperCudaArguments, + WhisperModel, ) from montreal_forced_aligner.utils import ( KaldiProcessWorker, @@ -91,12 +108,246 @@ if TYPE_CHECKING: from montreal_forced_aligner.abc import MetaDict -__all__ = ["Transcriber", "TranscriberMixin"] +__all__ = ["Transcriber", "TranscriberMixin", "WhisperTranscriber", "SpeechbrainTranscriber"] logger = logging.getLogger("mfa") -class TranscriberMixin(CorpusAligner): +class TranscriptionEvaluationMixin: + def __init__( + self, + evaluation_mode: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.evaluation_mode = evaluation_mode + self.export_output_directory = None + + def evaluate_transcriptions(self) -> None: + """ + Evaluates the transcripts if there are reference transcripts + + Returns + ------- + float, float + Sentence error rate and word error rate + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + logger.info("Evaluating transcripts...") + ser, wer, cer = self.compute_wer() + logger.info(f"SER: {100 * ser:.2f}%, WER: {100 * wer:.2f}%, CER: {100 * cer:.2f}%") + + def save_transcription_evaluation(self, output_directory: Path) -> None: + """ + Save transcription evaluation to an output directory + + Parameters + ---------- + output_directory: str + Directory to save evaluation + """ + output_path = output_directory.joinpath("transcription_evaluation.csv") + with mfa_open(output_path, "w") as f, self.session() as session: + writer = csv.writer(f) + writer.writerow( + [ + "file", + "speaker", + "begin", + "end", + "duration", + "word_count", + "oov_count", + "gold_transcript", + "hypothesis", + "WER", + "CER", + ] + ) + utterances = ( + session.query( + Speaker.name, + 
File.name, + Utterance.begin, + Utterance.end, + Utterance.duration, + Utterance.normalized_text, + Utterance.transcription_text, + Utterance.oovs, + Utterance.word_error_rate, + Utterance.character_error_rate, + ) + .join(Utterance.speaker) + .join(Utterance.file) + .filter(Utterance.normalized_text != None) # noqa + .filter(Utterance.normalized_text != "") + ) + + for ( + speaker, + file, + begin, + end, + duration, + text, + transcription_text, + oovs, + word_error_rate, + character_error_rate, + ) in utterances: + word_count = text.count(" ") + 1 + oov_count = oovs.count(" ") + 1 + writer.writerow( + [ + file, + speaker, + begin, + end, + duration, + word_count, + oov_count, + text, + transcription_text, + word_error_rate, + character_error_rate, + ] + ) + + def compute_wer(self) -> typing.Tuple[float, float, float]: + """ + Evaluates the transcripts if there are reference transcripts + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if not hasattr(self, "db_engine"): + raise Exception("Must be used as part of a class with a database engine") + # Sentence-level measures + incorrect = 0 + total_count = 0 + # Word-level measures + total_word_edits = 0 + total_word_length = 0 + + # Character-level measures + total_character_edits = 0 + total_character_length = 0 + + indices = [] + to_comp = [] + + update_mappings = [] + with self.session() as session: + utterances = session.query(Utterance) + utterances = utterances.filter(Utterance.normalized_text != None) # noqa + utterances = utterances.filter(Utterance.normalized_text != "") + for utt in utterances: + g = utt.normalized_text.split() + total_count += 1 + total_word_length += len(g) + character_length = len("".join(g)) + total_character_length += character_length + + if not utt.transcription_text: + incorrect += 1 + total_word_edits += len(g) + total_character_edits += character_length + update_mappings.append( + {"id": utt.id, "word_error_rate": 1.0, "character_error_rate": 1.0} + ) + continue + + h = utt.transcription_text.split() + if g != h: + indices.append(utt.id) + to_comp.append((g, h)) + incorrect += 1 + else: + update_mappings.append( + {"id": utt.id, "word_error_rate": 0.0, "character_error_rate": 0.0} + ) + + with ThreadPool(config.NUM_JOBS) as pool: + gen = pool.starmap(score_wer, to_comp) + for i, (word_edits, word_length, character_edits, character_length) in enumerate( + gen + ): + utt_id = indices[i] + update_mappings.append( + { + "id": utt_id, + "word_error_rate": word_edits / word_length, + "character_error_rate": character_edits / character_length, + } + ) + total_word_edits += word_edits + total_character_edits += character_edits + + bulk_update(session, Utterance, update_mappings) + session.commit() + ser = incorrect / total_count + wer = total_word_edits / total_word_length + cer = total_character_edits / total_character_length + return ser, wer, cer + + def export_transcriptions(self) -> None: + """Export transcriptions""" + with self.session() as session: + files = session.query(File).options( + selectinload(File.utterances), + selectinload(File.speakers), + joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration), + ) + for file in files: + utterance_count = len(file.utterances) + duration = file.sound_file.duration + + if utterance_count == 0: + logger.debug(f"Could not find any utterances for {file.name}") + elif ( + utterance_count == 1 + and file.utterances[0].begin == 0 + and 
file.utterances[0].end == duration + ): + output_format = "lab" + else: + output_format = TextgridFormats.SHORT_TEXTGRID + output_path = construct_output_path( + file.name, + file.relative_path, + self.export_output_directory, + output_format=output_format, + ) + data = file.construct_transcription_tiers() + if output_format == "lab": + for intervals in data.values(): + with mfa_open(output_path, "w") as f: + f.write(intervals["transcription"][0].label) + else: + tg = textgrid.Textgrid() + tg.minTimestamp = 0 + tg.maxTimestamp = round(duration, 5) + for speaker in file.speakers: + speaker = speaker.name + intervals = data[speaker]["transcription"] + tier = textgrid.IntervalTier( + speaker, + [x.to_tg_interval() for x in intervals], + minT=0, + maxT=round(duration, 5), + ) + + tg.addTier(tier) + tg.save(output_path, includeBlankSpaces=True, format=output_format) + + +class TranscriberMixin(CorpusAligner, TranscriptionEvaluationMixin): """Abstract class for MFA transcribers Parameters @@ -136,7 +387,6 @@ def __init__( first_max_active: int = 2000, language_model_weight: int = 10, word_insertion_penalty: float = 0.5, - evaluation_mode: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -150,7 +400,6 @@ def __init__( self.first_max_active = first_max_active self.language_model_weight = language_model_weight self.word_insertion_penalty = word_insertion_penalty - self.evaluation_mode = evaluation_mode self.alignment_mode = False def train_speaker_lm_arguments( @@ -427,230 +676,56 @@ def transcribe_utterances(self) -> None: """ Transcribe the corpus - See Also - -------- - :func:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction` - Multiprocessing helper function for each job - :func:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction` - Multiprocessing helper function for each job - :func:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction` - Multiprocessing helper function for each job - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - logger.info("Beginning transcription...") - workflow = self.current_workflow - if workflow.done: - logger.info("Transcription already done, skipping!") - return - try: - if workflow.workflow_type is WorkflowType.transcription: - self.uses_speaker_adaptation = False - - self.decode() - if workflow.workflow_type is WorkflowType.transcription: - logger.info("Performing speaker adjusted transcription...") - self.transcribe_fmllr() - self.lm_rescore() - self.carpa_lm_rescore() - self.collect_alignments() - if self.fine_tune: - self.fine_tune_alignments() - if self.evaluation_mode: - os.makedirs(self.working_log_directory, exist_ok=True) - self.evaluate_transcriptions() - with self.session() as session: - session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( - {"done": True} - ) - session.commit() - except Exception as e: - with self.session() as session: - session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( - {"dirty": True} - ) - session.commit() - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs) - e.update_log_file() - raise - - def evaluate_transcriptions(self) -> Tuple[float, float]: - """ - Evaluates the transcripts if there are reference transcripts - - Returns - ------- - float, float - Sentence error rate and word error rate - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If 
there were any errors in running Kaldi binaries - """ - logger.info("Evaluating transcripts...") - ser, wer, cer = self.compute_wer() - logger.info(f"SER: {100 * ser:.2f}%, WER: {100 * wer:.2f}%, CER: {100 * cer:.2f}%") - - def save_transcription_evaluation(self, output_directory: Path) -> None: - """ - Save transcription evaluation to an output directory - - Parameters - ---------- - output_directory: str - Directory to save evaluation - """ - output_path = output_directory.joinpath("transcription_evaluation.csv") - with mfa_open(output_path, "w") as f, self.session() as session: - writer = csv.writer(f) - writer.writerow( - [ - "file", - "speaker", - "begin", - "end", - "duration", - "word_count", - "oov_count", - "gold_transcript", - "hypothesis", - "WER", - "CER", - ] - ) - utterances = ( - session.query( - Speaker.name, - File.name, - Utterance.begin, - Utterance.end, - Utterance.duration, - Utterance.normalized_text, - Utterance.transcription_text, - Utterance.oovs, - Utterance.word_error_rate, - Utterance.character_error_rate, - ) - .join(Utterance.speaker) - .join(Utterance.file) - .filter(Utterance.normalized_text != None) # noqa - .filter(Utterance.normalized_text != "") - ) - - for ( - speaker, - file, - begin, - end, - duration, - text, - transcription_text, - oovs, - word_error_rate, - character_error_rate, - ) in utterances: - word_count = text.count(" ") + 1 - oov_count = oovs.count(" ") + 1 - writer.writerow( - [ - file, - speaker, - begin, - end, - duration, - word_count, - oov_count, - text, - transcription_text, - word_error_rate, - character_error_rate, - ] - ) - - def compute_wer(self) -> typing.Tuple[float, float, float]: - """ - Evaluates the transcripts if there are reference transcripts - - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - if not hasattr(self, "db_engine"): - raise Exception("Must be used as part of a class with a database engine") - logger.info("Evaluating transcripts...") - # Sentence-level measures - incorrect = 0 - total_count = 0 - # Word-level measures - total_word_edits = 0 - total_word_length = 0 - - # Character-level measures - total_character_edits = 0 - total_character_length = 0 - - indices = [] - to_comp = [] - - update_mappings = [] - with self.session() as session: - utterances = session.query(Utterance) - utterances = utterances.filter(Utterance.normalized_text != None) # noqa - utterances = utterances.filter(Utterance.normalized_text != "") - for utt in utterances: - g = utt.normalized_text.split() - total_count += 1 - total_word_length += len(g) - character_length = len("".join(g)) - total_character_length += character_length - - if not utt.transcription_text: - incorrect += 1 - total_word_edits += len(g) - total_character_edits += character_length - update_mappings.append( - {"id": utt.id, "word_error_rate": 1.0, "character_error_rate": 1.0} - ) - continue - - h = utt.transcription_text.split() - if g != h: - indices.append(utt.id) - to_comp.append((g, h)) - incorrect += 1 - else: - update_mappings.append( - {"id": utt.id, "word_error_rate": 0.0, "character_error_rate": 0.0} - ) - - with ThreadPool(config.NUM_JOBS) as pool: - gen = pool.starmap(score_wer, to_comp) - for i, (word_edits, word_length, character_edits, character_length) in enumerate( - gen - ): - utt_id = indices[i] - update_mappings.append( - { - "id": utt_id, - "word_error_rate": word_edits / word_length, - "character_error_rate": character_edits / character_length, - } - 
) - total_word_edits += word_edits - total_character_edits += character_edits + See Also + -------- + :func:`~montreal_forced_aligner.transcription.multiprocessing.DecodeFunction` + Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.LmRescoreFunction` + Multiprocessing helper function for each job + :func:`~montreal_forced_aligner.transcription.multiprocessing.CarpaLmRescoreFunction` + Multiprocessing helper function for each job - bulk_update(session, Utterance, update_mappings) - session.commit() - ser = incorrect / total_count - wer = total_word_edits / total_word_length - cer = total_character_edits / total_character_length - return ser, wer, cer + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + logger.info("Beginning transcription...") + workflow = self.current_workflow + if workflow.done: + logger.info("Transcription already done, skipping!") + return + try: + if workflow.workflow_type is WorkflowType.transcription: + self.uses_speaker_adaptation = False + + self.decode() + if workflow.workflow_type is WorkflowType.transcription: + logger.info("Performing speaker adjusted transcription...") + self.transcribe_fmllr() + self.lm_rescore() + self.carpa_lm_rescore() + self.collect_alignments() + if self.fine_tune: + self.fine_tune_alignments() + if self.evaluation_mode: + os.makedirs(self.working_log_directory, exist_ok=True) + self.evaluate_transcriptions() + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"done": True} + ) + session.commit() + except Exception as e: + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"dirty": True} + ) + session.commit() + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs) + e.update_log_file() + raise @property def transcribe_fmllr_options(self) -> MetaDict: @@ -965,7 +1040,7 @@ class Transcriber(TranscriberMixin, TopLevelMfaWorker): acoustic_model_path : str Path to acoustic model language_model_path : str - Path to language model model + Path to language model evaluation_mode: bool Flag for evaluating generated transcripts against the actual transcripts, defaults to False @@ -1199,58 +1274,6 @@ def setup(self) -> None: self.initialized = True logger.debug(f"Setup for transcription in {time.time() - begin:.3f} seconds") - def export_transcriptions(self) -> None: - """Export transcriptions""" - with self.session() as session: - files = session.query(File).options( - selectinload(File.utterances), - selectinload(File.speakers), - joinedload(File.sound_file, innerjoin=True).load_only(SoundFile.duration), - ) - for file in files: - utterance_count = len(file.utterances) - duration = file.sound_file.duration - - if utterance_count == 0: - logger.debug(f"Could not find any utterances for {file.name}") - elif ( - utterance_count == 1 - and file.utterances[0].begin == 0 - and file.utterances[0].end == duration - ): - output_format = "lab" - else: - output_format = TextgridFormats.SHORT_TEXTGRID - output_path = construct_output_path( - file.name, - file.relative_path, - self.export_output_directory, - output_format=output_format, - ) - data = file.construct_transcription_tiers() - if output_format == "lab": - for intervals in data.values(): - with mfa_open(output_path, "w") as f: - f.write(intervals["transcription"][0].label) - else: - tg = 
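For reference, the sentence, word, and character error rates logged by evaluate_transcriptions are pooled over the whole corpus rather than averaged per utterance: SER is the fraction of utterances with any mismatch, while WER and CER divide summed edit distances by the total reference length in words and characters. A minimal self-contained sketch of that bookkeeping (not the threaded score_wer helper used in the code)::

    def edit_distance(ref, hyp):
        # One-row Levenshtein distance; works on lists of words or strings of characters.
        d = list(range(len(hyp) + 1))
        for i, r in enumerate(ref, 1):
            prev, d[0] = d[0], i
            for j, h in enumerate(hyp, 1):
                prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (r != h))
        return d[-1]

    def corpus_error_rates(pairs):
        # pairs: iterable of (gold_text, hypothesis_text) strings
        incorrect = word_edits = word_len = char_edits = char_len = n = 0
        for gold, hyp in pairs:
            g, h = gold.split(), hyp.split()
            n += 1
            incorrect += g != h
            word_edits += edit_distance(g, h)
            word_len += len(g)
            char_edits += edit_distance("".join(g), "".join(h))
            char_len += len("".join(g))
        return incorrect / n, word_edits / word_len, char_edits / char_len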
textgrid.Textgrid() - tg.minTimestamp = 0 - tg.maxTimestamp = round(duration, 5) - for speaker in file.speakers: - speaker = speaker.name - intervals = data[speaker]["transcription"] - tier = textgrid.IntervalTier( - speaker, - [x.to_tg_interval() for x in intervals], - minT=0, - maxT=round(duration, 5), - ) - - tg.addTier(tier) - tg.save(output_path, includeBlankSpaces=True, format=output_format) - if self.evaluation_mode: - self.save_transcription_evaluation(self.export_output_directory) - def export_files( self, output_directory: Path, @@ -1275,3 +1298,318 @@ def export_files( self.export_transcriptions() else: self.export_textgrids(output_format, include_original_text) + if self.evaluation_mode: + self.save_transcription_evaluation(self.export_output_directory) + + +class HuggingFaceTranscriber( + AcousticCorpusMixin, + TranscriptionEvaluationMixin, + FileExporterMixin, + TopLevelMfaWorker, + DictionaryMixin, +): + def __init__( + self, + language: typing.Union[str, Language] = Language.unknown, + cuda: bool = False, + evaluation_mode: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.cuda = cuda + self.evaluation_mode = evaluation_mode + if not isinstance(language, Language): + language = Language[language] + self.language = language + self.transcription_function = None + + def get_tokenizers(self): + return self.tokenizer + + def setup(self) -> None: + self.initialize_database() + self._load_corpus() + self.initialize_jobs() + if self.evaluation_mode: + self._create_dummy_dictionary() + self.normalize_text() + self.create_new_current_workflow(WorkflowType.transcription) + wf = self.current_workflow + if wf.done: + logger.info("Transcription already done, skipping initialization.") + return + log_dir = self.working_directory.joinpath("log") + os.makedirs(log_dir, exist_ok=True) + + def transcribe_arguments(self) -> typing.List: + return [] + + def transcribe(self): + self.setup() + self.transcribe_utterances() + + def transcribe_utterances(self) -> None: + """ + Transcribe the corpus + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + logger.info("Beginning transcription...") + workflow = self.current_workflow + if workflow.done: + logger.info("Transcription already done, skipping!") + return + try: + if workflow.workflow_type is WorkflowType.transcription: + self.uses_speaker_adaptation = False + + arguments = self.transcribe_arguments() + if self.cuda: + config.update_configuration( + { + "USE_THREADING": True, + # "USE_MP": False, + } + ) + update_mapping = [] + with self.session() as session: + for u_id, transcript in run_kaldi_function( + self.transcription_function, arguments, total_count=self.num_utterances + ): + update_mapping.append({"id": u_id, "transcription_text": transcript}) + if update_mapping: + bulk_update(session, Utterance, update_mapping) + session.commit() + if self.evaluation_mode: + os.makedirs(self.working_log_directory, exist_ok=True) + self.evaluate_transcriptions() + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"done": True} + ) + session.commit() + except Exception as e: + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"dirty": True} + ) + session.commit() + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs) + e.update_log_file() + raise + + def export_files( + self, + 
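The write-back pattern used here (collect id/transcription_text mappings from the worker results, then one bulk_update before commit) avoids per-row ORM updates. As a rough illustration only, assuming a plain SQLAlchemy session from an initialized worker and purely illustrative utterance ids, roughly the same batch write could be expressed with SQLAlchemy's bulk mapping API::

    from montreal_forced_aligner.db import Utterance

    update_mapping = [
        {"id": 42, "transcription_text": "hello world"},   # ids are illustrative
        {"id": 43, "transcription_text": "good morning"},
    ]
    with transcriber.session() as session:  # assumes an initialized HuggingFaceTranscriber
        session.bulk_update_mappings(Utterance, update_mapping)
        session.commit()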
output_directory: Path, + ) -> None: + """ + Export transcriptions + + Parameters + ---------- + output_directory: str + Directory to save transcriptions + """ + self.export_output_directory = output_directory + os.makedirs(self.export_output_directory, exist_ok=True) + self.export_transcriptions() + if self.evaluation_mode: + self.save_transcription_evaluation(self.export_output_directory) + + +class WhisperTranscriber(HuggingFaceTranscriber): + ARCHITECTURES = ["distil-large-v3", "medium", "large-v3", "base", "tiny", "small"] + + def __init__(self, architecture: str = "distil-large-v3", **kwargs): + if not FOUND_WHISPER: + logger.error( + "Could not import faster_whisper, please ensure it is installed via `pip install faster-whisper`" + ) + sys.exit(1) + if architecture not in self.ARCHITECTURES: + raise ModelError( + f"The architecture {architecture} is not in: {', '.join(self.ARCHITECTURES)}" + ) + super().__init__(**kwargs) + self.architecture = architecture + self.model = None + self.transcription_function = WhisperAsrFunction + + def transcribe_arguments(self): + if self.cuda: + return [ + WhisperCudaArguments( + j.id, + getattr(self, "session", ""), + self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), + self.working_directory, + self.model, + self.language, + {"beam_size": 5}, + self.tokenizer if self.evaluation_mode else None, + self.cuda, + ) + for j in self.jobs + ] + return [ + WhisperArguments( + j.id, + getattr(self, "session" if config.USE_THREADING else "db_string", ""), + self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), + self.working_directory, + self.architecture, + self.language, + {"beam_size": 5}, + self.tokenizer if self.evaluation_mode else None, + self.cuda, + ) + for j in self.jobs + ] + + # noinspection PyTypeChecker + def setup(self) -> None: + """ + Sets up the corpus and speaker classifier + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if self.initialized: + return + iso_code = self.language.iso_code + if iso_code is None: + raise ModelError( + f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" + ) + try: + if self.cuda: + run_opts = {"device": "cuda", "compute_type": "int8"} + self.model = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + cpu_threads=config.NUM_JOBS, + num_workers=config.NUM_JOBS, + **run_opts, + ) + else: + # Download models if needed + _ = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + ) + except Exception: + raise ModelError( + f"Could not download whisper model with {self.architecture} and {self.language.name}" + ) + super().setup() + + +class SpeechbrainTranscriber(HuggingFaceTranscriber): + ARCHITECTURES = ["whisper-medium", "wav2vec2", "whisper-large-v2"] + + def __init__(self, architecture: str = "whisper-medium", **kwargs): + if not FOUND_SPEECHBRAIN: + logger.error( + "Could not import faster_whisper, please ensure it is installed via `pip install faster-whisper`" + ) + sys.exit(1) + if architecture not in self.ARCHITECTURES: + raise ModelError( + f"The architecture {architecture} is not in: {', '.join(self.ARCHITECTURES)}" + ) + self.architecture = architecture + super().__init__(**kwargs) + self.model = None + self.transcription_function = 
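Outside of MFA, the faster-whisper model constructed in setup() is driven through a very small API. A minimal standalone sketch of the CPU branch (audio path and language are illustrative; the architecture name is one of WhisperTranscriber.ARCHITECTURES, and beam_size matches the decode options above)::

    from faster_whisper import WhisperModel

    model = WhisperModel("distil-large-v3", device="cpu")
    segments, info = model.transcribe("speaker1/recording.wav", beam_size=5, language="en")
    transcript = " ".join(segment.text.strip() for segment in segments)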
SpeechbrainAsrFunction + + def transcribe_arguments(self): + if self.cuda: + return [ + SpeechbrainAsrCudaArguments( + j.id, + getattr(self, "session", ""), + self.working_log_directory.joinpath(f"speechbrain_asr.{j.id}.log"), + self.working_directory, + self.model, + self.tokenizer if self.evaluation_mode else None, + ) + for j in self.jobs + ] + return [ + SpeechbrainAsrArguments( + j.id, + getattr(self, "session" if config.USE_THREADING else "db_string", ""), + self.working_log_directory.joinpath(f"speechbrain_asr.{j.id}.log"), + self.working_directory, + self.architecture, + self.language, + self.tokenizer if self.evaluation_mode else None, + ) + for j in self.jobs + ] + + def setup(self) -> None: + """ + Sets up the corpus and speaker classifier + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if self.initialized: + return + common_voice_code = self.language.iso_code + if common_voice_code is None: + raise ModelError( + f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" + ) + model_key = f"speechbrain/asr-{self.architecture}-commonvoice-14-{common_voice_code}" + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + if self.architecture == "wav2vec2": + # Download models if needed + m = EncoderASR.from_hparams( + source=model_key, + savedir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "EncoderASR", model_key + ), + ) + else: + # Download models if needed + m = WhisperASR.from_hparams( + source=model_key, + savedir=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "WhisperASR", + model_key, + ), + ) + if self.cuda: + self.model = m + except ImportError: + raise + except Exception: + raise ModelError( + f"Could not download a speechbrain model with {self.architecture} and {self.language.name} ({model_key})" + ) + super().setup() diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index a6a08bcf..d5acb8e0 100644 --- a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -179,7 +179,7 @@ def parse_dictionary_file( float or None Correction factor for no silence before the pronunciation """ - prob_pattern = re.compile(r"\b\d+\.\d+\b") + prob_pattern = re.compile(r"\b(\d+\.\d+|1)\b") with mfa_open(path) as f: for i, line in enumerate(f): line = line.strip() @@ -663,11 +663,7 @@ def run_kaldi_function( stopped = Event() error_dict = {} return_queue = Queue(10000) - callback_interval = 100 - if total_count is not None: - callback_interval = int(total_count / 100) - if callback_interval <= 0: - callback_interval = 2 + callback_interval = 10 num_done = 0 last_update = 0 pbar = None @@ -675,7 +671,7 @@ def run_kaldi_function( if not config.QUIET and total_count: pbar = tqdm(total=total_count, maxinterval=0) progress_callback = pbar.update - + update_time = time.time() if config.USE_MP: procs = [] for args in arguments: @@ -699,9 +695,11 @@ def run_kaldi_function( num_done += result else: num_done += 1 - if num_done - last_update > callback_interval: - progress_callback(num_done - last_update) - last_update = num_done + if time.time() - update_time >= callback_interval: + if num_done - last_update > 0: + progress_callback(num_done - last_update) + last_update = num_done + update_time = time.time() if isinstance(return_queue, queue.Queue): return_queue.task_done() except queue.Empty: diff --git a/montreal_forced_aligner/vad/multiprocessing.py 
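For context, the SpeechBrain models fetched in setup() can also be exercised directly. A minimal sketch of the wav2vec2 branch for English, with the model key following the speechbrain/asr-{architecture}-commonvoice-14-{code} pattern used above (savedir and audio path are illustrative; on newer SpeechBrain releases the same class is importable from speechbrain.inference.ASR)::

    from speechbrain.pretrained import EncoderASR

    asr = EncoderASR.from_hparams(
        source="speechbrain/asr-wav2vec2-commonvoice-14-en",
        savedir="models/EncoderASR/speechbrain/asr-wav2vec2-commonvoice-14-en",
    )
    transcript = asr.transcribe_file("recording.wav")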
b/montreal_forced_aligner/vad/multiprocessing.py index aa6d43c1..1d6c44d4 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -94,6 +94,48 @@ class SegmentTranscriptArguments(MfaArguments): decode_options: MetaDict +def segment_utterance( + utterance: KalpyUtterance, + vad_model: VAD, + segmentation_options: MetaDict, + mfcc_options: MetaDict = None, + vad_options: MetaDict = None, + allow_empty: bool = True, +) -> typing.List[Segment]: + """ + Split an utterance and its transcript into multiple transcribed utterances + + Parameters + ---------- + utterance: :class:`~kalpy.utterance.Utterance` + Utterance to split + vad_model: :class:`~speechbrain.pretrained.VAD` or None + VAD model from SpeechBrain, if None, then Kaldi's energy-based VAD is used + segmentation_options: dict[str, Any] + Segmentation options + mfcc_options: dict[str, Any], optional + MFCC options for energy based VAD + vad_options: dict[str, Any], optional + Options for energy based VAD + + Returns + ------- + list[:class:`~kalpy.utterance.Utterance`] + Split utterances + """ + if vad_model is None: + segments = segment_utterance_vad( + utterance, mfcc_options, vad_options, segmentation_options + ) + else: + segments = segment_utterance_vad_speech_brain( + utterance, vad_model, segmentation_options, allow_empty=allow_empty + ) + if not segments: + return [utterance.segment] + return segments + + def segment_utterance_transcript( acoustic_model: AcousticModel, utterance: KalpyUtterance, @@ -180,12 +222,9 @@ def segment_utterance_transcript( cmvn_computer = CmvnComputer() cmvn = cmvn_computer.compute_cmvn_from_features([utterance.mfccs]) current_transcript = utterance.transcript - if vad_model is None: - segments = segment_utterance_vad( - utterance, mfcc_options, vad_options, segmentation_options - ) - else: - segments = segment_utterance_vad_speech_brain(utterance, vad_model, segmentation_options) + segments = segment_utterance( + utterance, vad_model, segmentation_options, mfcc_options, vad_options + ) if not segments: return [utterance] config = LatticeFasterDecoderConfig() @@ -381,17 +420,16 @@ def merge_segments( or s.begin > merged_segments[-1].end + min_pause_duration or s.end - merged_segments[-1].begin > max_segment_length ): - if s.end - s.begin > min_pause_duration: - if merged_segments and snap_boundary_threshold: - boundary_gap = s.begin - merged_segments[-1].end - if boundary_gap < snap_boundary_threshold: - half_boundary = boundary_gap / 2 - else: - half_boundary = snap_boundary_threshold / 2 - merged_segments[-1].end += half_boundary - s.begin -= half_boundary - - merged_segments.append(s) + if merged_segments and snap_boundary_threshold: + boundary_gap = s.begin - merged_segments[-1].end + if boundary_gap < snap_boundary_threshold: + half_boundary = boundary_gap / 2 + else: + half_boundary = snap_boundary_threshold / 2 + merged_segments[-1].end += half_boundary + s.begin -= half_boundary + + merged_segments.append(s) else: merged_segments[-1].end = s.end return [x for x in merged_segments if x.end - x.begin > min_segment_length] @@ -402,17 +440,32 @@ def segment_utterance_vad( mfcc_options: MetaDict, vad_options: MetaDict, segmentation_options: MetaDict, + adaptive: bool = True, + allow_empty: bool = True, ) -> typing.List[Segment]: + mfcc_options["use_energy"] = True + mfcc_options["raw_energy"] = False + mfcc_options["dither"] = 0.0 + mfcc_options["energy_floor"] = 0.0 mfcc_computer = MfccComputer(**mfcc_options) - vad_computer = 
VadComputer(**vad_options) feats = mfcc_computer.compute_mfccs_for_export(utterance.segment, compress=False) + if adaptive: + vad_options["energy_mean_scale"] = 0.0 + mfccs = feats.numpy() + print(mfccs[:, 0]) + min_0, max_0 = mfccs[:, 0].min(), mfccs[:, 0].max() + range = max_0 - min_0 + thresh = (range * 0.6) + min_0 + print("THRESHOLD", thresh, min_0, max_0) + vad_options["energy_threshold"] = mfccs[:, 0].mean() + vad_computer = VadComputer(**vad_options) vad = vad_computer.compute_vad(feats).numpy() segments = get_initial_segmentation(vad, mfcc_computer.frame_shift) segments = merge_segments( segments, segmentation_options["close_th"], segmentation_options["large_chunk_size"], - segmentation_options["len_th"], + segmentation_options["len_th"] if allow_empty else 0.02, ) new_segments = [] for s in segments: @@ -427,7 +480,10 @@ def segment_utterance_vad( def segment_utterance_vad_speech_brain( - utterance: KalpyUtterance, vad_model: VAD, segmentation_options: MetaDict + utterance: KalpyUtterance, + vad_model: VAD, + segmentation_options: MetaDict, + allow_empty: bool = True, ) -> typing.List[Segment]: y = utterance.segment.wave prob_chunks = vad_model.get_speech_prob_chunk( @@ -444,12 +500,14 @@ def segment_utterance_vad_speech_brain( # Apply energy-based VAD on the detected speech segments if segmentation_options["apply_energy_VAD"]: - boundaries = vad_model.energy_VAD( + vad_boundaries = vad_model.energy_VAD( utterance.segment.file_path, boundaries, activation_th=segmentation_options["en_activation_th"], deactivation_th=segmentation_options["en_deactivation_th"], ) + if vad_boundaries.size(0) != 0 or allow_empty: + boundaries = vad_boundaries # Merge short segments boundaries = vad_model.merge_close_segments( @@ -457,13 +515,20 @@ def segment_utterance_vad_speech_brain( ) # Remove short segments - boundaries = vad_model.remove_short_segments(boundaries, len_th=segmentation_options["len_th"]) + filtered_boundaries = vad_model.remove_short_segments( + boundaries, len_th=segmentation_options["len_th"] + ) + if filtered_boundaries.size(0) != 0 or allow_empty: + boundaries = filtered_boundaries # Double check speech segments if segmentation_options["double_check"]: - boundaries = vad_model.double_check_speech_segments( + checked_boundaries = vad_model.double_check_speech_segments( boundaries, utterance.segment.file_path, speech_th=segmentation_options["speech_th"] ) + if checked_boundaries.size(0) != 0 or allow_empty: + boundaries = checked_boundaries + print(boundaries) boundaries[:, 0] -= round(segmentation_options["close_th"] / 2, 3) boundaries[:, 1] += round(segmentation_options["close_th"] / 2, 3) boundaries = boundaries.numpy() diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index ffb56854..b109f7d5 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -44,6 +44,7 @@ SegmentTranscriptFunction, SegmentVadArguments, SegmentVadFunction, + segment_utterance, segment_utterance_transcript, ) @@ -444,6 +445,20 @@ def export_files(self, output_directory: str, output_format: Optional[str] = Non ): f.save(output_directory, output_format=output_format) + def segment_utterance(self, utterance_id: int, allow_empty: bool = True): + with self.session() as session: + utterance = full_load_utterance(session, utterance_id) + + new_utterances = segment_utterance( + utterance.to_kalpy(), + self.vad_model if self.speechbrain else None, + self.segmentation_options, + mfcc_options=self.mfcc_options if not 
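The adaptive branch above derives the energy-based VAD threshold from the first cepstral coefficient (C0, roughly log energy) of the utterance itself instead of a fixed value; note that the code as written ultimately assigns the per-frame mean of C0 as energy_threshold. A standalone numpy sketch of the range-based formulation it also computes, with the 0.6 proportion taken from the code above::

    import numpy as np

    def adaptive_energy_threshold(mfccs: np.ndarray, proportion: float = 0.6) -> float:
        # mfccs: frames x coefficients, with C0 in column 0 (use_energy=True, raw_energy=False)
        c0 = mfccs[:, 0]
        return float(c0.min() + proportion * (c0.max() - c0.min()))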
self.speechbrain else None, + vad_options=self.vad_options if not self.speechbrain else None, + allow_empty=allow_empty, + ) + return new_utterances + class TranscriptionSegmenter( VadConfigMixin, TranscriberMixin, SpeechbrainSegmenterMixin, TopLevelMfaWorker diff --git a/montreal_forced_aligner/validation/corpus_validator.py b/montreal_forced_aligner/validation/corpus_validator.py index dbf28c1c..fc6bf8d3 100644 --- a/montreal_forced_aligner/validation/corpus_validator.py +++ b/montreal_forced_aligner/validation/corpus_validator.py @@ -74,9 +74,14 @@ def working_log_directory(self) -> str: """Working log directory""" return self.working_directory.joinpath("log") - def analyze_setup(self) -> None: + def analyze_setup(self, output_directory: Path = None) -> None: """ Analyzes the setup process and outputs info to the console + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ begin = time.time() @@ -120,16 +125,23 @@ def analyze_setup(self) -> None: self.analyze_textgrid_read_errors() logger.info("Dictionary") - self.analyze_oovs() + self.analyze_oovs(output_directory=output_directory) - def analyze_oovs(self) -> None: + def analyze_oovs(self, output_directory: Path = None) -> None: """ Analyzes OOVs in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Out of vocabulary words") - output_dir = self.output_directory - oov_path = os.path.join(output_dir, "oovs_found.txt") - utterance_oov_path = os.path.join(output_dir, "utterance_oovs.txt") + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) + oov_path = os.path.join(output_directory, "oovs_found.txt") + utterance_oov_path = os.path.join(output_directory, "utterance_oovs.txt") total_instances = 0 with mfa_open(utterance_oov_path, "w") as f, self.session() as session: @@ -155,7 +167,7 @@ def analyze_oovs(self) -> None: ) self.oovs_found.update(oovs) if self.oovs_found: - self.save_oovs_found(self.output_directory) + self.save_oovs_found(output_directory) logger.warning(f"{len(self.oovs_found)} OOV word types") logger.warning(f"{total_instances}total OOV tokens") logger.warning( @@ -169,16 +181,23 @@ def analyze_oovs(self) -> None: "least some missing words." 
) - def analyze_wav_errors(self) -> None: + def analyze_wav_errors(self, output_directory: Path = None) -> None: """ Analyzes any sound file issues in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Sound file read errors") - output_dir = self.output_directory + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) wav_read_errors = self.sound_file_errors if wav_read_errors: - path = os.path.join(output_dir, "sound_file_errors.csv") + path = os.path.join(output_directory, "sound_file_errors.csv") with mfa_open(path, "w") as f: for p in wav_read_errors: f.write(f"{p}\n") @@ -190,15 +209,23 @@ def analyze_wav_errors(self) -> None: else: logger.info("There were no issues reading sound files.") - def analyze_missing_features(self) -> None: + def analyze_missing_features(self, output_directory: Path = None) -> None: """ Analyzes issues in feature generation in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Feature generation") if self.ignore_acoustics: logger.info("Acoustic feature generation was skipped.") return - output_dir = self.output_directory + + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) with self.session() as session: utterances = ( session.query(File.name, File.relative_path, Utterance.begin, Utterance.end) @@ -206,10 +233,9 @@ def analyze_missing_features(self) -> None: .filter(Utterance.ignored == True) # noqa ) if utterances.count(): - path = os.path.join(output_dir, "missing_features.csv") + path = os.path.join(output_directory, "missing_features.csv") with mfa_open(path, "w") as f: for file_name, relative_path, begin, end in utterances: - f.write(f"{relative_path.joinpath(file_name)},{begin},{end}\n") logger.error( @@ -219,15 +245,22 @@ def analyze_missing_features(self) -> None: else: logger.info("There were no utterances missing features.") - def analyze_files_with_no_transcription(self) -> None: + def analyze_files_with_no_transcription(self, output_directory: Path = None) -> None: """ Analyzes issues with sound files that have no transcription files in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Files without transcriptions") - output_dir = self.output_directory + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) if self.no_transcription_files: - path = os.path.join(output_dir, "missing_transcriptions.csv") + path = os.path.join(output_directory, "missing_transcriptions.csv") with mfa_open(path, "w") as f: for file_path in self.no_transcription_files: f.write(f"{file_path}\n") @@ -238,15 +271,22 @@ def analyze_files_with_no_transcription(self) -> None: else: logger.info("There were no sound files missing transcriptions.") - def analyze_transcriptions_with_no_wavs(self) -> None: + def analyze_transcriptions_with_no_wavs(self, output_directory: Path = None) -> None: """ Analyzes issues with transcription that have no sound files in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Transcriptions without sound files") - output_dir = 
self.output_directory + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) if self.transcriptions_without_wavs: - path = os.path.join(output_dir, "transcriptions_missing_sound_files.csv") + path = os.path.join(output_directory, "transcriptions_missing_sound_files.csv") with mfa_open(path, "w") as f: for file_path in self.transcriptions_without_wavs: f.write(f"{file_path}\n") @@ -257,15 +297,22 @@ def analyze_transcriptions_with_no_wavs(self) -> None: else: logger.info("There were no transcription files missing sound files.") - def analyze_textgrid_read_errors(self) -> None: + def analyze_textgrid_read_errors(self, output_directory: Path = None) -> None: """ Analyzes issues with reading TextGrid files in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("TextGrid read errors") - output_dir = self.output_directory + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) if self.textgrid_read_errors: - path = os.path.join(output_dir, "textgrid_read_errors.txt") + path = os.path.join(output_directory, "textgrid_read_errors.txt") with mfa_open(path, "w") as f: for e in self.textgrid_read_errors: f.write( @@ -278,15 +325,22 @@ def analyze_textgrid_read_errors(self) -> None: else: logger.info("There were no issues reading TextGrids.") - def analyze_unreadable_text_files(self) -> None: + def analyze_unreadable_text_files(self, output_directory: Path = None) -> None: """ Analyzes issues with reading text files in the corpus and constructs message + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ logger.info("Text file read errors") - output_dir = self.output_directory + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) if self.decode_error_files: - path = os.path.join(output_dir, "utf8_read_errors.csv") + path = os.path.join(output_directory, "utf8_read_errors.csv") with mfa_open(path, "w") as f: for file_path in self.decode_error_files: f.write(f"{file_path}\n") @@ -297,17 +351,25 @@ def analyze_unreadable_text_files(self) -> None: else: logger.info("There were no issues reading text files.") - def test_utterance_transcriptions(self) -> None: + def test_utterance_transcriptions(self, output_directory: Path = None) -> None: """ Tests utterance transcriptions with simple unigram models based on the utterance text and frequent words in the corpus + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in + Raises ------ :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ + if output_directory is None: + output_directory = self.output_directory + os.makedirs(output_directory, exist_ok=True) try: self.train_speaker_lms() @@ -336,8 +398,8 @@ def test_utterance_transcriptions(self) -> None: else: logger.error(f"{cer*100:.2f}% character error rate") - self.save_transcription_evaluation(self.output_directory) - out_path = os.path.join(self.output_directory, "transcription_evaluation.csv") + self.save_transcription_evaluation(output_directory) + out_path = os.path.join(output_directory, "transcription_evaluation.csv") logger.info(f"See {out_path} for more details.") except Exception as e: @@ -389,7 +451,6 @@ def parse_parameters( 
args: Optional[Dict[str, Any]] = None, unknown_args: Optional[typing.Iterable[str]] = None, ) -> MetaDict: - """ Parse parameters for validation from a config path or command-line arguments @@ -469,8 +530,6 @@ def setup(self) -> None: self.normalize_text() - self.save_oovs_found(self.output_directory) - begin = time.time() self.write_lexicon_information() self.write_training_information() @@ -485,7 +544,6 @@ def setup(self) -> None: self.generate_features() logger.debug(f"Generated features in {time.time() - begin:.3f} seconds") begin = time.time() - self.save_oovs_found(self.output_directory) logger.debug(f"Calculated OOVs in {time.time() - begin:.3f} seconds") self.setup_trainers() @@ -496,14 +554,19 @@ def setup(self) -> None: e.update_log_file() raise - def validate(self) -> None: + def validate(self, output_directory: Path = None) -> None: """ Performs validation of the corpus + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ begin = time.time() logger.debug(f"Setup took {time.time() - begin:.3f} seconds") self.setup() - self.analyze_setup() + self.analyze_setup(output_directory=output_directory) logger.debug(f"Setup took {time.time() - begin:.3f} seconds") if self.ignore_acoustics: logger.info("Skipping test alignments.") @@ -511,7 +574,7 @@ def validate(self) -> None: logger.info("Training") self.train() if self.test_transcriptions: - self.test_utterance_transcriptions() + self.test_utterance_transcriptions(output_directory=output_directory) self.get_phone_confidences() @@ -551,8 +614,6 @@ def setup(self) -> None: self.initialize_jobs() self.normalize_text() - self.save_oovs_found(self.output_directory) - if self.ignore_acoustics: logger.info("Skipping acoustic feature generation") else: @@ -572,14 +633,19 @@ def setup(self) -> None: e.update_log_file() raise - def validate(self) -> None: + def validate(self, output_directory: Path = None) -> None: """ Performs validation of the corpus + + Parameters + ---------- + output_directory: Path, optional + Optional directory to save output files in """ self.initialize_database() self.create_new_current_workflow(WorkflowType.alignment) self.setup() - self.analyze_setup() + self.analyze_setup(output_directory=output_directory) self.analyze_missing_phones() if self.ignore_acoustics: logger.info("Skipping test alignments.") @@ -594,7 +660,7 @@ def validate(self) -> None: self.transcribe() self.collect_alignments() if self.test_transcriptions: - self.test_utterance_transcriptions() + self.test_utterance_transcriptions(output_directory=output_directory) self.collect_alignments() self.transcription_done = True with self.session() as session: diff --git a/tests/test_commandline_train.py b/tests/test_commandline_train.py index 9e63b749..f4046047 100644 --- a/tests/test_commandline_train.py +++ b/tests/test_commandline_train.py @@ -115,6 +115,7 @@ def test_train_and_align_basic_speaker_dict( english_mfa_rules_path, "--topology_path", english_mfa_topology_path, + "--use_postgres", ] command = [str(x) for x in command] result = click.testing.CliRunner(mix_stderr=False).invoke( diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 7fd3bfb8..2e4ce4ab 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -1,8 +1,10 @@ import os import click.testing +import pytest from montreal_forced_aligner.command_line.mfa import mfa_cli +from montreal_forced_aligner.transcription.multiprocessing import FOUND_SPEECHBRAIN, 
FOUND_WHISPER def test_transcribe( @@ -46,6 +48,81 @@ def test_transcribe( assert os.path.exists(os.path.join(output_path, "michael", "acoustic_corpus.lab")) +def test_transcribe_speechbrain( + combined_corpus_dir, + generated_dir, + transcription_acoustic_model, + transcription_language_model, + temp_dir, + db_setup, +): + if not FOUND_SPEECHBRAIN: + pytest.skip("SpeechBrain not installed") + output_path = generated_dir.joinpath("transcribe_test_sb") + command = [ + "transcribe_speechbrain", + combined_corpus_dir, + "english", + output_path, + "--architecture", + "wav2vec2", + "--clean", + "--no_debug", + "--evaluate", + "--no_cuda", + "--use_postgres", + ] + command = [str(x) for x in command] + result = click.testing.CliRunner(mix_stderr=False).invoke( + mfa_cli, command, catch_exceptions=True + ) + print(result.stdout) + print(result.stderr) + if result.exception: + print(result.exc_info) + raise result.exception + assert not result.return_value + assert os.path.exists(output_path) + + +def test_transcribe_whisper( + combined_corpus_dir, + generated_dir, + transcription_acoustic_model, + transcription_language_model, + temp_dir, + db_setup, +): + if not FOUND_WHISPER: + pytest.skip("Faster-whisper not installed") + output_path = generated_dir.joinpath("transcribe_test_whisper") + command = [ + "transcribe_whisper", + combined_corpus_dir, + output_path, + "--language", + "english", + "--architecture", + "tiny", + "--clean", + "--no_debug", + "--evaluate", + "--no_cuda", + "--use_postgres", + ] + command = [str(x) for x in command] + result = click.testing.CliRunner(mix_stderr=False).invoke( + mfa_cli, command, catch_exceptions=True + ) + print(result.stdout) + print(result.stderr) + if result.exception: + print(result.exc_info) + raise result.exception + assert not result.return_value + assert os.path.exists(output_path) + + def test_transcribe_arpa( basic_corpus_dir, english_dictionary, diff --git a/tests/test_g2p.py b/tests/test_g2p.py index 559375ca..99dfd3a5 100644 --- a/tests/test_g2p.py +++ b/tests/test_g2p.py @@ -87,5 +87,7 @@ def test_generator_pretrained(english_g2p_model, temp_dir, db_setup): ) gen.setup() results = gen.generate_pronunciations() - assert len(results["petted"]) == 3 + for word, prons in results: + if word == "petted": + assert len(prons) == 3 gen.cleanup() From 53f3974d498797b664f452f1c61b69e3a9231582 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Wed, 4 Sep 2024 11:35:02 -0700 Subject: [PATCH 02/16] Whisper with hf conversion --- .pre-commit-config.yaml | 2 - environment.yml | 2 + .../command_line/anchor.py | 11 +- .../transcription/multiprocessing.py | 347 ++++++++++++++++-- .../transcription/transcriber.py | 202 ++++++++-- tests/test_commandline_transcribe.py | 4 +- 6 files changed, 498 insertions(+), 70 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91711b0f..84efa958 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,3 @@ -default_language_version: - python: python3.11 repos: - repo: https://github.com/psf/black rev: 23.9.1 diff --git a/environment.yml b/environment.yml index e8256b47..124e364a 100644 --- a/environment.yml +++ b/environment.yml @@ -51,6 +51,8 @@ dependencies: - sudachipy - sudachidict-core - spacy-pkuseg + - transformers + - tokenizers - pip: - build - twine diff --git a/montreal_forced_aligner/command_line/anchor.py b/montreal_forced_aligner/command_line/anchor.py index 5e99b506..f043382c 100644 --- a/montreal_forced_aligner/command_line/anchor.py +++ 
b/montreal_forced_aligner/command_line/anchor.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -import sys import requests import rich_click as click @@ -20,14 +19,8 @@ def anchor_cli(*args, **kwargs) -> None: # pragma: no cover """ Launch Anchor Annotator (if installed) """ - try: - from anchor.command_line import main - except ImportError as e: - logger.error(f"Exception: {e}") - logger.error( - "Anchor annotator utility is not installed, please install it via `conda install -c conda-forge anchor-annotator`." - ) - sys.exit(1) + from anchor.command_line import main # noqa + if config.VERBOSE: try: from anchor._version import version diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index 34d52427..ea16b0b2 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -14,6 +14,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Dict +import numpy as np +import sqlalchemy from _kalpy.fstext import ConstFst, VectorFst from _kalpy.lat import CompactLatticeWriter from _kalpy.lm import ConstArpaLm @@ -34,9 +36,8 @@ from montreal_forced_aligner.data import Language, MfaArguments, PhoneType from montreal_forced_aligner.db import File, Job, Phone, SoundFile, Utterance from montreal_forced_aligner.diarization.multiprocessing import UtteranceFileLoader -from montreal_forced_aligner.helper import mfa_open from montreal_forced_aligner.tokenization.simple import SimpleTokenizer -from montreal_forced_aligner.utils import thread_logger +from montreal_forced_aligner.utils import mfa_open, thread_logger if TYPE_CHECKING: from dataclasses import dataclass @@ -51,6 +52,15 @@ WhisperModel = None FOUND_WHISPER = False +try: + from transformers import WhisperForConditionalGeneration, WhisperProcessor + + FOUND_TRANSFORMERS = True +except ImportError: + WhisperForConditionalGeneration = None + WhisperProcessor = None + FOUND_TRANSFORMERS = False + try: import warnings @@ -91,7 +101,7 @@ "CreateHclgFunction", "FOUND_SPEECHBRAIN", "FOUND_WHISPER", - "WhisperModel", + "WhisperForConditionalGeneration", "WhisperASR", "EncoderASR", "SpeechbrainAsrArguments", @@ -203,7 +213,7 @@ class WhisperArguments(MfaArguments): """ working_directory: Path - model_size: str + model_id: str language: Language decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] @@ -227,6 +237,58 @@ class WhisperCudaArguments(MfaArguments): Current working directory """ + working_directory: Path + model_id: str + model: WhisperForConditionalGeneration + processor: WhisperProcessor + language: Language + decode_options: MetaDict + tokenizer: typing.Optional[SimpleTokenizer] + cuda: bool + + +@dataclass +class FasterWhisperArguments(MfaArguments): + """ + Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + + working_directory: Path + model_size: str + language: Language + decode_options: MetaDict + tokenizer: typing.Optional[SimpleTokenizer] + cuda: bool + + +@dataclass +class FasterWhisperCudaArguments(MfaArguments): + """ + Arguments for 
:class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` + + Parameters + ---------- + job_name: int + Integer ID of the job + session: :class:`sqlalchemy.orm.scoped_session` or str + SqlAlchemy scoped session or string for database connections + log_path: :class:`~pathlib.Path` + Path to save logging information during the run + working_directory: :class:`~pathlib.Path` + Current working directory + """ + working_directory: Path model: WhisperModel language: Language @@ -706,31 +768,141 @@ def _run(self) -> None: raise exception -def get_suppressed_tokens(model: WhisperModel) -> typing.List[int]: +class WhisperUtteranceLoader(threading.Thread): + """ + Helper process for loading utterance waveforms in parallel with embedding extraction + + Parameters + ---------- + job_name: int + Job identifier + session: sqlalchemy.orm.scoped_session + Session + return_q: :class:`~queue.Queue` + Queue to put waveforms + stopped: :class:`~threading.Event` + Check for whether the process to exit gracefully + finished_adding: :class:`~threading.Event` + Check for whether the worker has processed all utterances + """ + + def __init__( + self, + job_name: int, + session: sqlalchemy.orm.scoped_session, + return_q: queue.Queue, + stopped: threading.Event, + finished_adding: threading.Event, + processor: WhisperProcessor, + ): + super().__init__() + self.job_name = job_name + self.session = session + self.return_q = return_q + self.stopped = stopped + self.finished_adding = finished_adding + self.processor = processor + + def run(self) -> None: + """ + Run the waveform loading job + """ + + batch_size = config.NUM_JOBS + + with self.session() as session: + try: + utterances = ( + session.query( + Utterance.id, + SoundFile.sound_file_path, + Utterance.begin, + Utterance.end, + Utterance.channel, + ) + .join(Utterance.file) + .join(File.sound_file) + .filter(Utterance.duration <= 30) + .order_by(Utterance.duration.desc()) + ) + if not utterances.count(): + self.finished_adding.set() + return + raw_audio = [] + utterance_ids = [] + for u in utterances: + if self.stopped.is_set(): + break + utterance_ids.append(u[0]) + segment = Segment(u[1], u[2], u[3], u[4]) + audio = segment.load_audio().astype(np.float32) + raw_audio.append(audio) + if len(utterance_ids) >= batch_size: + inputs = self.processor( + raw_audio, + return_tensors="pt", + truncation=False, + return_attention_mask=True, + sampling_rate=16_000, + ) + self.return_q.put((utterance_ids, inputs)) + raw_audio = [] + utterance_ids = [] + if utterance_ids: + inputs = self.processor( + raw_audio, + return_tensors="pt", + truncation=False, + return_attention_mask=True, + sampling_rate=16_000, + ) + self.return_q.put((utterance_ids, inputs)) + except Exception as e: + self.return_q.put(e) + finally: + self.finished_adding.set() + + +def get_suppressed_tokens( + whisper_processor: typing.Union[WhisperProcessor, WhisperModel] +) -> typing.List[int]: suppressed = [] - i = 32 - roman_numeral_pattern = re.compile(r"x+(vi+|i{2,}|i+v|x+)", flags=re.IGNORECASE) - while True: - token = model.hf_tokenizer.id_to_token(i) - if token is None: - break - if not token.startswith("<|"): - if ( - not token.isalpha() - or re.search(r"\d", token) - or roman_numeral_pattern.search(token) - or re.search(r"[IXV]{2,}", token) - or re.search(r"i{2,}$", token) - or re.search(r"^(Ġ)?x{2,}", token) - or re.search(r"^(Ġ)?vi{2,}", token) - or re.match(r"^(Ġ)?[XV]$", token) - ): - suppressed.append(i) - i += 1 + roman_numeral_pattern = 
re.compile(r"(x+(vi+|i{2,}|i+v|x+))", flags=re.IGNORECASE) + case_roman_numeral_pattern = re.compile(r"([IXV]{2,}|i{2,}$|^.?x{2,}|^.?vi{2,}|\d)") + if isinstance(whisper_processor, WhisperProcessor): + for token_id, token in whisper_processor.tokenizer.decoder.items(): + if token_id in whisper_processor.tokenizer.all_special_ids: + continue + if not token: + continue + if not token.startswith("<|"): + if ( + not token.isalpha() + or roman_numeral_pattern.search(token) + or case_roman_numeral_pattern.search(token) + or re.match(r"^.?[XV]$", token) + ): + suppressed.append(token_id) + else: + i = 0 + while True: + token = whisper_processor.hf_tokenizer.id_to_token(i) + if token is None: + break + if not token.startswith("<|"): + if ( + not token.isalpha() + or roman_numeral_pattern.search(token) + or case_roman_numeral_pattern.search(token) + or re.match(r"^.?[XV]$", token) + ): + suppressed.append(i) + i += 1 + return suppressed -class WhisperAsrFunction(KaldiFunction): +class FasterWhisperFunction(KaldiFunction): """ Multiprocessing function for performing decoding @@ -749,14 +921,14 @@ class WhisperAsrFunction(KaldiFunction): Arguments for the function """ - def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]): + def __init__(self, args: typing.Union[FasterWhisperArguments, FasterWhisperCudaArguments]): super().__init__(args) self.working_directory = args.working_directory self.cuda = args.cuda self.model = None self.language = args.language self.decode_options = args.decode_options - if isinstance(args, WhisperCudaArguments): + if isinstance(args, FasterWhisperCudaArguments): self.model = args.model else: self.model = args.model_size @@ -767,7 +939,7 @@ def _run(self) -> None: model = self.model if isinstance(model, str): if self.cuda: - run_opts = {"device": "cuda", "compute_type": "int8"} + run_opts = {"device": "cuda", "compute_type": "float16"} else: run_opts = {"device": "cpu"} model = WhisperModel( @@ -798,6 +970,7 @@ def _run(self) -> None: .join(Utterance.file) .join(File.sound_file) .filter(Utterance.job_id == self.job_name) + .filter(Utterance.duration > 30) ) for u in utterances: segment = Segment(u[1], u[2], u[3], u[4]) @@ -825,6 +998,124 @@ def _run(self) -> None: current_index = 0 +class WhisperAsrFunction(KaldiFunction): + """ + Multiprocessing function for performing decoding + + See Also + -------- + :meth:`.TranscriberMixin.transcribe_utterances` + Main function that calls this function in parallel + :meth:`.TranscriberMixin.decode_arguments` + Job method for generating arguments for this function + :kaldi_src:`gmm-latgen-faster` + Relevant Kaldi binary + + Parameters + ---------- + args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments` + Arguments for the function + """ + + def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]): + super().__init__(args) + self.working_directory = args.working_directory + self.cuda = args.cuda + self.model_id = args.model_id + self.model = None + self.processor = None + self.language = args.language + self.decode_options = args.decode_options + if isinstance(args, WhisperCudaArguments): + self.model = args.model + self.processor = args.processor + self.tokenizer = args.tokenizer + + def _run(self) -> None: + """Run the function""" + processor = self.processor + if processor is None: + processor = WhisperProcessor.from_pretrained(self.model_id) + processor.tokenizer.add_prefix_space = False + language = None + if self.language is not Language.unknown: + 
language = self.language.iso_code + model = self.model + if model is None: + suppressed = get_suppressed_tokens(processor) + model = WhisperForConditionalGeneration.from_pretrained(self.model_id) + model.generation_config.suppress_tokens += suppressed + model.generation_config.suppress_tokens = list( + set(model.generation_config.suppress_tokens) + ) + model.generation_config.suppress_tokens.sort() + if language is not None: + model.generation_config.forced_decoder_ids = None + if self.cuda: + model.to("cuda") + return_q = queue.Queue(2) + finished_adding = threading.Event() + stopped = threading.Event() + loader = WhisperUtteranceLoader( + self.job_name, + self.session, + return_q, + stopped, + finished_adding, + processor, + ) + loader.start() + exception = None + current_index = 0 + while True: + try: + batch = return_q.get(timeout=1) + except queue.Empty: + if finished_adding.is_set(): + break + continue + if stopped.is_set(): + continue + if isinstance(batch, Exception): + exception = batch + stopped.set() + continue + utterance_ids, inputs = batch + if self.cuda: + inputs = inputs.to("cuda", torch.float16) + result = model.generate( + **inputs, + condition_on_prev_tokens=False, + temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + logprob_threshold=-1.0, + compression_ratio_threshold=1.35, + return_timestamps=False, + language=language, + pad_token_id=processor.tokenizer.eos_token_id, + ) + + decoded = processor.batch_decode( + result, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + for i, u_id in enumerate(utterance_ids): + text = decoded[i] + if self.tokenizer is not None: + text = self.tokenizer(text)[0] + self.callback((int(u_id), text)) + del utterance_ids + del inputs + del result + del decoded + current_index += 1 + if current_index > 10: + torch.cuda.empty_cache() + current_index = 0 + + loader.join() + if exception: + raise exception + + class LmRescoreFunction(KaldiFunction): """ Multiprocessing function rescore lattices by replacing the small G.fst with the medium G.fst diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index ca14fc62..cd0e534e 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -7,6 +7,7 @@ import collections import csv +import gc import logging import os import shutil @@ -79,6 +80,9 @@ DecodePhoneArguments, DecodePhoneFunction, EncoderASR, + FasterWhisperArguments, + FasterWhisperCudaArguments, + FasterWhisperFunction, FinalFmllrArguments, FinalFmllrFunction, FmllrRescoreArguments, @@ -96,7 +100,9 @@ WhisperASR, WhisperAsrFunction, WhisperCudaArguments, + WhisperForConditionalGeneration, WhisperModel, + get_suppressed_tokens, ) from montreal_forced_aligner.utils import ( KaldiProcessWorker, @@ -139,7 +145,7 @@ def evaluate_transcriptions(self) -> None: """ logger.info("Evaluating transcripts...") ser, wer, cer = self.compute_wer() - logger.info(f"SER: {100 * ser:.2f}%, WER: {100 * wer:.2f}%, CER: {100 * cer:.2f}%") + logger.info(f"SER: {100 * ser: .2f}%, WER: {100 * wer: .2f}%, CER: {100 * cer: .2f}%") def save_transcription_evaluation(self, output_directory: Path) -> None: """ @@ -443,7 +449,7 @@ def train_speaker_lms(self) -> None: TrainSpeakerLmFunction, arguments, total_count=self.num_speakers ): pass - logger.debug(f"Compiling speaker language models took {time.time() - begin:.3f} seconds") + logger.debug(f"Compiling speaker language models took {time.time() - begin: .3f} seconds") @property 
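A condensed, standalone version of the batched short-form decoding path above (the openai/whisper-tiny model id and the silent dummy waveforms are placeholders; real inputs are 16 kHz float32 arrays of at most 30 seconds, as enforced by the utterance loader)::

    import numpy as np
    import torch
    from transformers import WhisperForConditionalGeneration, WhisperProcessor

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    audio_batch = [np.zeros(16_000, dtype=np.float32), np.zeros(32_000, dtype=np.float32)]
    inputs = processor(
        audio_batch,
        sampling_rate=16_000,
        return_tensors="pt",
        return_attention_mask=True,
    )
    with torch.no_grad():
        predicted_ids = model.generate(**inputs, language="en", return_timestamps=False)
    texts = processor.batch_decode(predicted_ids, skip_special_tokens=True)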
def model_directory(self) -> Path: @@ -1272,7 +1278,7 @@ def setup(self) -> None: self.setup_acoustic_model() self.create_decoding_graph() self.initialized = True - logger.debug(f"Setup for transcription in {time.time() - begin:.3f} seconds") + logger.debug(f"Setup for transcription in {time.time() - begin: .3f} seconds") def export_files( self, @@ -1428,7 +1434,7 @@ class WhisperTranscriber(HuggingFaceTranscriber): def __init__(self, architecture: str = "distil-large-v3", **kwargs): if not FOUND_WHISPER: logger.error( - "Could not import faster_whisper, please ensure it is installed via `pip install faster-whisper`" + "Could not import transformers, please ensure it is installed via `conda install transformers`" ) sys.exit(1) if architecture not in self.ARCHITECTURES: @@ -1444,20 +1450,51 @@ def transcribe_arguments(self): if self.cuda: return [ WhisperCudaArguments( - j.id, + 1, getattr(self, "session", ""), - self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), + self.working_log_directory.joinpath("whisper_asr.log"), self.working_directory, + f"openai/whisper-{self.architecture}", self.model, + self.processor, self.language, {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, ) - for j in self.jobs ] return [ WhisperArguments( + j.id, + getattr(self, "session" if config.USE_THREADING else "db_string", ""), + self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), + self.working_directory, + config.TEMPORARY_DIRECTORY.joinpath("models", "whisper", self.architecture), + self.language, + {"beam_size": 5}, + self.tokenizer if self.evaluation_mode else None, + self.cuda, + ) + for j in self.jobs + ] + + def faster_whisper_arguments(self): + if self.cuda: + return [ + FasterWhisperCudaArguments( + 1, + getattr(self, "session", ""), + self.working_log_directory.joinpath("whisper_asr.log"), + self.working_directory, + self.model, + self.language, + {"beam_size": 5}, + self.tokenizer if self.evaluation_mode else None, + self.cuda, + ) + ] + return [ + FasterWhisperArguments( j.id, getattr(self, "session" if config.USE_THREADING else "db_string", ""), self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), @@ -1471,6 +1508,91 @@ def transcribe_arguments(self): for j in self.jobs ] + def transcribe_utterances(self) -> None: + super().transcribe_utterances() + workflow = self.current_workflow + iso_code = self.language.iso_code + if iso_code is None: + raise ModelError( + f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" + ) + try: + arguments = self.faster_whisper_arguments() + update_mapping = [] + with self.session() as session: + num_utterances = session.query(Utterance).filter(Utterance.duration > 30).count() + if not num_utterances: + return + if self.cuda: + import torch + + run_opts = {"device": "cuda", "compute_type": "float16"} + del self.model + gc.collect() + torch.cuda.empty_cache() + self.model = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + cpu_threads=config.NUM_JOBS, + num_workers=config.NUM_JOBS, + **run_opts, + ) + config.update_configuration( + { + "USE_THREADING": True, + # "USE_MP": False, + } + ) + else: + # Download models if needed + _ = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + ) + logger.info("Transcribing longer utterances (>30 seconds)...") + for u_id, 
transcript in run_kaldi_function( + FasterWhisperFunction, arguments, total_count=num_utterances + ): + update_mapping.append({"id": u_id, "transcription_text": transcript}) + if update_mapping: + bulk_update(session, Utterance, update_mapping) + session.commit() + if self.evaluation_mode: + os.makedirs(self.working_log_directory, exist_ok=True) + self.evaluate_transcriptions() + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"done": True} + ) + session.commit() + except Exception as e: + with self.session() as session: + session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( + {"dirty": True} + ) + session.commit() + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs) + e.update_log_file() + raise + finally: + if self.cuda: + import torch + + del self.model + gc.collect() + torch.cuda.empty_cache() + # noinspection PyTypeChecker def setup(self) -> None: """ @@ -1489,32 +1611,54 @@ def setup(self) -> None: f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" ) try: + from transformers import WhisperProcessor + + model_path = config.TEMPORARY_DIRECTORY.joinpath( + "models", "whisper", self.architecture + ) + if not model_path.exists(): + subprocess.call( + [ + "python", + "-m", + "transformers.models.whisper.convert_openai_to_hf", + "--checkpoint_path", + self.architecture, + "--pytorch_dump_folder_path", + str(model_path), + "--convert_preprocessor", + "True", + ] + ) if self.cuda: - run_opts = {"device": "cuda", "compute_type": "int8"} - self.model = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, - cpu_threads=config.NUM_JOBS, - num_workers=config.NUM_JOBS, - **run_opts, + import torch + from transformers.utils import is_flash_attn_2_available + + self.processor = WhisperProcessor.from_pretrained(model_path) + suppressed = get_suppressed_tokens(self.processor) + attn_implementation = ( + "flash_attention_2" if is_flash_attn_2_available() else "sdpa" ) - else: - # Download models if needed - _ = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, + logger.debug(f"Using {attn_implementation} for attention") + + self.model = WhisperForConditionalGeneration.from_pretrained( + model_path, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation=attn_implementation, + ) + self.model.generation_config.suppress_tokens += suppressed + self.model.generation_config.suppress_tokens = list( + set(self.model.generation_config.suppress_tokens) ) + self.model.generation_config.suppress_tokens.sort() + if self.language.iso_code is not None: + self.model.generation_config.forced_decoder_ids = None + self.model.to("cuda") + else: + _ = WhisperForConditionalGeneration.from_pretrained(model_path) except Exception: + raise raise ModelError( f"Could not download whisper model with {self.architecture} and {self.language.name}" ) diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 2e4ce4ab..66e25191 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -94,7 +94,7 @@ def test_transcribe_whisper( db_setup, ): if not FOUND_WHISPER: - pytest.skip("Faster-whisper not installed") + pytest.skip("transformers not installed") output_path = 
generated_dir.joinpath("transcribe_test_whisper") command = [ "transcribe_whisper", @@ -107,7 +107,7 @@ def test_transcribe_whisper( "--clean", "--no_debug", "--evaluate", - "--no_cuda", + "--cuda", "--use_postgres", ] command = [str(x) for x in command] From ac817686347f08862af9c3edc5cc8fa17c41b719 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Fri, 6 Sep 2024 10:35:04 -0700 Subject: [PATCH 03/16] Working pipeline for huggingface whisper --- .pre-commit-config.yaml | 16 +- environment.yml | 10 +- montreal_forced_aligner/abc.py | 5 +- .../command_line/transcribe.py | 72 ++++- .../online/transcription.py | 96 +++++- .../transcription/multiprocessing.py | 209 +++++++++---- .../transcription/transcriber.py | 282 +++++++++++------- .../vad/multiprocessing.py | 185 +++++++++--- montreal_forced_aligner/vad/segmenter.py | 39 ++- pyproject.toml | 5 + tests/test_commandline_transcribe.py | 10 +- 11 files changed, 667 insertions(+), 262 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 84efa958..a7c50fe9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,10 +1,24 @@ repos: + - repo: local + hooks: + - id: profile-check + name: no profiling + entry: '@profile' + language: pygrep + types: [ python ] + - id: print-check + name: no print statements + entry: '\bprint\(' + language: pygrep + types: [ python ] + files: ^montreal_forced_aligner/ + exclude: ^montreal_forced_aligner/command_line/transcribe.py - repo: https://github.com/psf/black rev: 23.9.1 hooks: - id: black - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + rev: 7.0.0 hooks: - id: flake8 additional_dependencies: diff --git a/environment.yml b/environment.yml index 124e364a..a4b2efd2 100644 --- a/environment.yml +++ b/environment.yml @@ -51,17 +51,21 @@ dependencies: - sudachipy - sudachidict-core - spacy-pkuseg + # Whisper dependencies - transformers - tokenizers + - accelerate + - tiktoken - pip: - build - twine - - speechbrain - - kenlm - - pygtrie - faster-whisper - python-mecab-ko - jamo - pythainlp - hanziconv - dragonmapper + # Speechbrain dependencies + - speechbrain + - kenlm + - pygtrie diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index d7f345fa..80666584 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -283,7 +283,8 @@ def initialize_database(self) -> None: ) except Exception: raise DatabaseError( - f"There was an error connecting to the {config.CURRENT_PROFILE_NAME} MFA database server. " + f"There was an error connecting to the {config.CURRENT_PROFILE_NAME} MFA database server " + f"at {config.database_socket()}. 
" "Please ensure the server is initialized (mfa server init) or running (mfa server start)" ) exist_check = False @@ -304,7 +305,7 @@ def initialize_database(self) -> None: conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector")) conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_trgm")) conn.execute(sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS pg_stat_statements")) - conn.execute(sqlalchemy.text(f"select setseed({config.SEED/32768})")) + conn.execute(sqlalchemy.text(f"select setseed({config.SEED / 32768})")) conn.commit() MfaSqlBase.metadata.create_all(self.db_engine) diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index b379221f..ddfc2d46 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -1,9 +1,11 @@ """Command line functions for transcribing corpora""" from __future__ import annotations +import sys from pathlib import Path import rich_click as click +from kalpy.data import Segment from montreal_forced_aligner import config from montreal_forced_aligner.command_line.utils import ( @@ -13,11 +15,16 @@ validate_language_model, ) from montreal_forced_aligner.data import Language +from montreal_forced_aligner.online.transcription import ( + transcribe_utterance_online_faster_whisper, + transcribe_utterance_online_whisper, +) from montreal_forced_aligner.transcription.transcriber import ( SpeechbrainTranscriber, Transcriber, WhisperTranscriber, ) +from montreal_forced_aligner.utils import mfa_open __all__ = ["transcribe_corpus_cli", "transcribe_speechbrain_cli", "transcribe_whisper_cli"] @@ -251,12 +258,10 @@ def transcribe_speechbrain_cli(context, **kwargs) -> None: short_help="Transcribe utterances using a Whisper ASR model via faster-whisper", ) @click.argument( - "corpus_directory", - type=click.Path(exists=True, file_okay=False, dir_okay=True, path_type=Path), -) -@click.argument( - "output_directory", type=click.Path(file_okay=False, dir_okay=True, path_type=Path) + "input_path", + type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path), ) +@click.argument("output_path", type=click.Path(file_okay=True, dir_okay=True, path_type=Path)) @click.option( "--architecture", help="Model size to use", @@ -302,6 +307,18 @@ def transcribe_speechbrain_cli(context, **kwargs) -> None: help="Evaluate the transcription against golden texts.", default=False, ) +@click.option( + "--vad", + is_flag=True, + help="Use VAD to split utterances.", + default=False, +) +@click.option( + "--incremental", + is_flag=True, + help="Save outputs immediately and use previous progress.", + default=False, +) @common_options @click.help_option("-h", "--help") @click.pass_context @@ -314,16 +331,49 @@ def transcribe_whisper_cli(context, **kwargs) -> None: config.update_configuration(kwargs) config_path = kwargs.get("config_path", None) - corpus_directory = kwargs["corpus_directory"].absolute() - output_directory = kwargs["output_directory"] + incremental = kwargs.get("incremental", False) + input_path: Path = kwargs["input_path"].absolute() + output_path: Path = kwargs["output_path"] + corpus_root = input_path + if not corpus_root.is_dir(): + corpus_root = corpus_root.parent + transcriber = WhisperTranscriber( - corpus_directory=corpus_directory, + corpus_directory=corpus_root, + export_directory=output_path if incremental else None, **WhisperTranscriber.parse_parameters(config_path, context.params, context.args), ) try: - 
transcriber.setup() - transcriber.transcribe() - transcriber.export_files(output_directory) + if not input_path.is_dir(): + segment = Segment(input_path) + faster_whisper = segment.wave.shape[0] / 16_000 > 30 + faster_whisper = False + transcriber.setup_model(online=True, faster_whisper=faster_whisper) + + if faster_whisper: + text = transcribe_utterance_online_faster_whisper( + transcriber.model, segment, language=transcriber.language + ) + else: + text = transcribe_utterance_online_whisper( + transcriber.model, + transcriber.processor, + segment, + language=transcriber.language, + segmenter=transcriber.segmenter, + ) + if str(output_path) == "-": + print(text) # noqa + sys.exit(0) + output_path.parent.mkdir(parents=True, exist_ok=True) + with mfa_open(output_path, "w") as f: + f.write(text) + del transcriber.model + elif input_path.is_dir(): + transcriber.setup() + transcriber.transcribe() + if not incremental: + transcriber.export_files(output_path) except Exception: transcriber.dirty = True raise diff --git a/montreal_forced_aligner/online/transcription.py b/montreal_forced_aligner/online/transcription.py index 633f8580..208e8b97 100644 --- a/montreal_forced_aligner/online/transcription.py +++ b/montreal_forced_aligner/online/transcription.py @@ -3,9 +3,11 @@ import typing +import numpy as np import torch from _kalpy.fstext import ConstFst from _kalpy.matrix import DoubleMatrix, FloatMatrix +from kalpy.data import Segment from kalpy.feat.cmvn import CmvnComputer from kalpy.fstext.lexicon import LexiconCompiler from kalpy.gmm.data import HierarchicalCtm @@ -17,11 +19,18 @@ from montreal_forced_aligner.models import AcousticModel from montreal_forced_aligner.tokenization.simple import SimpleTokenizer from montreal_forced_aligner.transcription.multiprocessing import ( + FOUND_FASTER_WHISPER, + FOUND_SPEECHBRAIN, + FOUND_TRANSFORMERS, EncoderASR, WhisperASR, + WhisperForConditionalGeneration, WhisperModel, + WhisperProcessor, get_suppressed_tokens, ) +from montreal_forced_aligner.vad.multiprocessing import segment_utterance_vad_speech_brain +from montreal_forced_aligner.vad.segmenter import SpeechbrainSegmenterMixin def transcribe_utterance_online( @@ -95,36 +104,111 @@ def transcribe_utterance_online( return ctm -def transcribe_utterance_online_whisper( +def transcribe_utterance_online_faster_whisper( model: WhisperModel, - utterance: KalpyUtterance, + segment: Segment, beam: int = 5, language: Language = Language.unknown, tokenizer: SimpleTokenizer = None, ) -> str: - segment = utterance.segment - waveform = segment.load_audio() + if not FOUND_FASTER_WHISPER: + raise Exception( + "Could not import faster-whisper, please ensure it is installed via `pip install faster-whisper`" + ) + waveform = segment.wave suppressed = get_suppressed_tokens(model) segments, info = model.transcribe( waveform, language=language.iso_code, beam_size=beam, suppress_tokens=suppressed, - temperature=0.0, + temperature=1.0, condition_on_previous_text=False, ) - text = " ".join([x.text for x in segments]) + texts = [] + for x in segments: + if x.no_speech_prob > 0.6: + continue + texts.append(x.text) + text = " ".join(texts) text = text.replace(" ", " ") if tokenizer is not None: text = tokenizer(text)[0] return text.strip() +def transcribe_utterance_online_whisper( + model: WhisperForConditionalGeneration, + processor: WhisperProcessor, + segment: Segment, + beam_size: int = 5, + language: Language = Language.unknown, + tokenizer: SimpleTokenizer = None, + segmenter: SpeechbrainSegmenterMixin = None, +) -> str: + 
if not FOUND_TRANSFORMERS: + raise Exception( + "Could not import transformers, please ensure it is installed via `conda install transformers`" + ) + raw_audio = [] + if segmenter is None: + audio = segment.wave.astype(np.float32) + raw_audio.append(audio) + else: + segments = segment_utterance_vad_speech_brain( + segment, segmenter.vad_model, segmenter.segmentation_options, allow_empty=True + ) + if len(segments) < 2: + raw_audio.append(segment.wave.astype(np.float32)) + else: + for s in segments: + raw_audio.append(s.wave.astype(np.float32)) + inputs = processor( + raw_audio, + return_tensors="pt", + truncation=True, + return_attention_mask=True, + sampling_rate=16_000, + pad_to_multiple_of=128, + device=model.device.type, + ) + inputs = inputs.to(model.device, model.dtype) + if language is not Language.unknown: + language = language.iso_code + else: + language = None + result = model.generate( + **inputs, + condition_on_prev_tokens=False, + temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if segmenter is None else 0.0, + logprob_threshold=-1.0, + compression_ratio_threshold=1.35, + return_timestamps=False, + language=language, + ) + decoded = [] + special_ids = processor.tokenizer.all_special_ids + for r in result: + r = [t for t in r if t not in special_ids] + tokens = processor.tokenizer.convert_tokens_to_string( + processor.tokenizer.convert_ids_to_tokens(r) + ).strip() + decoded.append(tokens) + text = " ".join(decoded) + if tokenizer is not None: + text = tokenizer(text)[0] + return text.strip() + + def transcribe_utterance_online_speechbrain( model: typing.Union[WhisperASR, EncoderASR], utterance: KalpyUtterance, tokenizer: SimpleTokenizer = None, ) -> str: + if not FOUND_SPEECHBRAIN: + raise Exception( + "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" + ) segment = utterance.segment waveform = segment.load_audio() waveform = model.audio_normalizer(waveform, 16000).unsqueeze(0) diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index ea16b0b2..24def2f7 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -38,19 +38,22 @@ from montreal_forced_aligner.diarization.multiprocessing import UtteranceFileLoader from montreal_forced_aligner.tokenization.simple import SimpleTokenizer from montreal_forced_aligner.utils import mfa_open, thread_logger +from montreal_forced_aligner.vad.multiprocessing import segment_utterance_vad_speech_brain if TYPE_CHECKING: from dataclasses import dataclass + + from montreal_forced_aligner.vad.segmenter import SpeechbrainSegmenterMixin else: from dataclassy import dataclass try: from faster_whisper import WhisperModel - FOUND_WHISPER = True + FOUND_FASTER_WHISPER = True except ImportError: WhisperModel = None - FOUND_WHISPER = False + FOUND_FASTER_WHISPER = False try: from transformers import WhisperForConditionalGeneration, WhisperProcessor @@ -100,7 +103,8 @@ "LmRescoreFunction", "CreateHclgFunction", "FOUND_SPEECHBRAIN", - "FOUND_WHISPER", + "FOUND_FASTER_WHISPER", + "FOUND_TRANSFORMERS", "WhisperForConditionalGeneration", "WhisperASR", "EncoderASR", @@ -218,6 +222,7 @@ class WhisperArguments(MfaArguments): decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] cuda: bool + export_directory: typing.Optional[Path] @dataclass @@ -241,10 +246,12 @@ class WhisperCudaArguments(MfaArguments): model_id: str model: WhisperForConditionalGeneration 
processor: WhisperProcessor + segmenter: SpeechbrainSegmenterMixin language: Language decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] cuda: bool + export_directory: typing.Optional[Path] @dataclass @@ -270,6 +277,7 @@ class FasterWhisperArguments(MfaArguments): decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] cuda: bool + export_directory: typing.Optional[Path] @dataclass @@ -295,6 +303,7 @@ class FasterWhisperCudaArguments(MfaArguments): decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] cuda: bool + export_directory: typing.Optional[Path] @dataclass @@ -703,6 +712,9 @@ def _run(self) -> None: "EncoderASR", model, ), + huggingface_cache_dir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "hf_cache" + ), run_opts=run_opts, ) else: @@ -715,6 +727,9 @@ def _run(self) -> None: "WhisperASR", model, ), + huggingface_cache_dir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "hf_cache" + ), run_opts=run_opts, ) @@ -794,6 +809,9 @@ def __init__( stopped: threading.Event, finished_adding: threading.Event, processor: WhisperProcessor, + segmenter: SpeechbrainSegmenterMixin = None, + export_directory: Path = None, + device: str = "cpu", ): super().__init__() self.job_name = job_name @@ -802,6 +820,9 @@ def __init__( self.stopped = stopped self.finished_adding = finished_adding self.processor = processor + self.segmenter = segmenter + self.export_directory = export_directory + self.device = device def run(self) -> None: """ @@ -819,44 +840,87 @@ def run(self) -> None: Utterance.begin, Utterance.end, Utterance.channel, + File.relative_path, + File.name, ) .join(Utterance.file) .join(File.sound_file) - .filter(Utterance.duration <= 30) - .order_by(Utterance.duration.desc()) ) + if self.segmenter is None: + utterances = utterances.filter(Utterance.duration <= 30) + utterances = utterances.order_by(Utterance.duration.desc()) + else: + utterances = utterances.order_by(Utterance.speaker_id) if not utterances.count(): self.finished_adding.set() return raw_audio = [] utterance_ids = [] + export_paths = [] for u in utterances: if self.stopped.is_set(): break - utterance_ids.append(u[0]) segment = Segment(u[1], u[2], u[3], u[4]) - audio = segment.load_audio().astype(np.float32) - raw_audio.append(audio) - if len(utterance_ids) >= batch_size: + export_path = None + if self.export_directory is not None: + export_path = self.export_directory.joinpath(u[5], u[6] + ".lab") + if export_path.exists(): + continue + utterance_ids.append(u[0]) + if self.segmenter is None: + audio = segment.load_audio().astype(np.float32) + raw_audio.append(audio) + export_paths.append(export_path) + if len(utterance_ids) >= batch_size: + inputs = self.processor( + raw_audio, + return_tensors="pt", + truncation=True, + return_attention_mask=True, + sampling_rate=16_000, + device=self.device, + ) + self.return_q.put((utterance_ids, inputs)) + raw_audio = [] + utterance_ids = [] + export_paths = [] + else: + segments = segment_utterance_vad_speech_brain( + segment, + self.segmenter.vad_model, + self.segmenter.segmentation_options, + allow_empty=True, + ) + if not segments: + continue + if len(segments) == 1: + raw_audio.append(segment.wave.astype(np.float32)) + else: + for s in segments: + raw_audio.append(s.wave.astype(np.float32)) inputs = self.processor( raw_audio, return_tensors="pt", - truncation=False, + truncation=True, return_attention_mask=True, sampling_rate=16_000, + device=self.device, ) - self.return_q.put((utterance_ids, inputs)) + self.return_q.put((u[0], 
inputs, export_path)) raw_audio = [] utterance_ids = [] + export_paths = [] + if utterance_ids: inputs = self.processor( raw_audio, return_tensors="pt", - truncation=False, + truncation=True, return_attention_mask=True, sampling_rate=16_000, + device=self.device, ) - self.return_q.put((utterance_ids, inputs)) + self.return_q.put((utterance_ids, inputs, export_paths)) except Exception as e: self.return_q.put(e) finally: @@ -867,38 +931,42 @@ def get_suppressed_tokens( whisper_processor: typing.Union[WhisperProcessor, WhisperModel] ) -> typing.List[int]: suppressed = [] - roman_numeral_pattern = re.compile(r"(x+(vi+|i{2,}|i+v|x+))", flags=re.IGNORECASE) - case_roman_numeral_pattern = re.compile(r"([IXV]{2,}|i{2,}$|^.?x{2,}|^.?vi{2,}|\d)") + import unicodedata + + alpha_pattern = re.compile(r"\w", flags=re.UNICODE) + roman_numeral_pattern = re.compile(r"^(x+(vi+|i+|i?v|x+))$", flags=re.IGNORECASE) + case_roman_numeral_pattern = re.compile(r"(^[IXV]{2,}$|^[xvi]+i$|^x{2,}$|\d)") + + def _should_suppress(t): + if t.startswith("<|"): + return False + if any(unicodedata.category(c) in {"Mn", "Mc"} for c in t): + return False + if ( + roman_numeral_pattern.search(t) + or case_roman_numeral_pattern.search(t) + or re.match(r"^[XV]$", t) + or not alpha_pattern.search(t) + ): + return True + return False + if isinstance(whisper_processor, WhisperProcessor): - for token_id, token in whisper_processor.tokenizer.decoder.items(): - if token_id in whisper_processor.tokenizer.all_special_ids: - continue + for token_id in range(whisper_processor.tokenizer.vocab_size): + token = whisper_processor.tokenizer.convert_tokens_to_string( + whisper_processor.tokenizer.convert_ids_to_tokens([token_id]) + ).strip() if not token: continue - if not token.startswith("<|"): - if ( - not token.isalpha() - or roman_numeral_pattern.search(token) - or case_roman_numeral_pattern.search(token) - or re.match(r"^.?[XV]$", token) - ): - suppressed.append(token_id) + if _should_suppress(token): + suppressed.append(token_id) else: - i = 0 - while True: - token = whisper_processor.hf_tokenizer.id_to_token(i) - if token is None: - break - if not token.startswith("<|"): - if ( - not token.isalpha() - or roman_numeral_pattern.search(token) - or case_roman_numeral_pattern.search(token) - or re.match(r"^.?[XV]$", token) - ): - suppressed.append(i) - i += 1 - + for token_id in range(whisper_processor.hf_tokenizer.eot): + token = whisper_processor.hf_tokenizer.decode([token_id]).strip() + if not token: + continue + if _should_suppress(token): + suppressed.append(token_id) return suppressed @@ -984,9 +1052,6 @@ def _run(self) -> None: **transcribe_opts, ) text = " ".join([x.text for x in segments]) - del waveform - del segments - del info if self.tokenizer is not None: text = self.tokenizer(text)[0] self.callback((u[0], text)) @@ -1020,15 +1085,19 @@ class WhisperAsrFunction(KaldiFunction): def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]): super().__init__(args) self.working_directory = args.working_directory + self.working_directory = args.working_directory self.cuda = args.cuda self.model_id = args.model_id self.model = None self.processor = None + self.segmenter = None self.language = args.language self.decode_options = args.decode_options + self.export_directory = args.export_directory if isinstance(args, WhisperCudaArguments): self.model = args.model self.processor = args.processor + self.segmenter = args.segmenter self.tokenizer = args.tokenizer def _run(self) -> None: @@ -1053,6 +1122,7 @@ def _run(self) 
-> None: model.generation_config.forced_decoder_ids = None if self.cuda: model.to("cuda") + special_ids = processor.tokenizer.all_special_ids return_q = queue.Queue(2) finished_adding = threading.Event() stopped = threading.Event() @@ -1063,10 +1133,16 @@ def _run(self) -> None: stopped, finished_adding, processor, + segmenter=self.segmenter, + export_directory=self.export_directory, + device="cuda" if self.cuda else "cpu", ) loader.start() exception = None current_index = 0 + cache_threshold = 10 + if self.segmenter is None: + cache_threshold = 100 while True: try: batch = return_q.get(timeout=1) @@ -1080,34 +1156,47 @@ def _run(self) -> None: exception = batch stopped.set() continue - utterance_ids, inputs = batch - if self.cuda: - inputs = inputs.to("cuda", torch.float16) + utterance_ids, inputs, export_paths = batch + inputs = inputs.to(model.device, model.dtype) result = model.generate( **inputs, condition_on_prev_tokens=False, - temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0), + temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if self.segmenter is None else 0.0, logprob_threshold=-1.0, compression_ratio_threshold=1.35, return_timestamps=False, language=language, - pad_token_id=processor.tokenizer.eos_token_id, ) - decoded = processor.batch_decode( - result, skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - for i, u_id in enumerate(utterance_ids): - text = decoded[i] + decoded = [] + for r in result: + r = [t for t in r if t not in special_ids] + tokens = processor.tokenizer.convert_tokens_to_string( + processor.tokenizer.convert_ids_to_tokens(r) + ).strip() + decoded.append(tokens) + if isinstance(utterance_ids, list): + for i, u_id in enumerate(utterance_ids): + text = decoded[i] + if self.tokenizer is not None: + text = self.tokenizer(text)[0] + if export_paths[i] is not None: + export_paths[i].parent.mkdir(parents=True, exist_ok=True) + with mfa_open(export_paths[i], "w") as f: + f.write(text) + self.callback((int(u_id), text)) + else: + text = " ".join(decoded) if self.tokenizer is not None: text = self.tokenizer(text)[0] - self.callback((int(u_id), text)) - del utterance_ids - del inputs - del result - del decoded + + if export_paths is not None: + export_paths.parent.mkdir(parents=True, exist_ok=True) + with mfa_open(export_paths, "w") as f: + f.write(text) + self.callback((utterance_ids, text)) current_index += 1 - if current_index > 10: + if False and current_index > cache_threshold: torch.cuda.empty_cache() current_index = 0 diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index cd0e534e..3f5e8571 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -69,8 +69,9 @@ from montreal_forced_aligner.models import AcousticModel, LanguageModel from montreal_forced_aligner.textgrid import construct_output_path from montreal_forced_aligner.transcription.multiprocessing import ( + FOUND_FASTER_WHISPER, FOUND_SPEECHBRAIN, - FOUND_WHISPER, + FOUND_TRANSFORMERS, CarpaLmRescoreArguments, CarpaLmRescoreFunction, CreateHclgArguments, @@ -1427,16 +1428,45 @@ def export_files( if self.evaluation_mode: self.save_transcription_evaluation(self.export_output_directory) + def cleanup(self) -> None: + if self.cuda: + import torch + + gc.collect() + torch.cuda.empty_cache() + super().cleanup() + class WhisperTranscriber(HuggingFaceTranscriber): - ARCHITECTURES = ["distil-large-v3", "medium", "large-v3", "base", "tiny", "small"] + ARCHITECTURES = 
["large-v3", "distil-large-v3", "medium", "base", "tiny", "small"] - def __init__(self, architecture: str = "distil-large-v3", **kwargs): - if not FOUND_WHISPER: + def __init__( + self, + architecture: str = "distil-large-v3", + vad: bool = False, + export_directory: Path = None, + **kwargs, + ): + from montreal_forced_aligner.vad.segmenter import ( + FOUND_SPEECHBRAIN, + SpeechbrainSegmenterMixin, + ) + + if not FOUND_TRANSFORMERS: logger.error( "Could not import transformers, please ensure it is installed via `conda install transformers`" ) sys.exit(1) + if not FOUND_FASTER_WHISPER: + logger.error( + "Could not import faster-whisper, please ensure it is installed via `pip install faster-whisper`" + ) + sys.exit(1) + if vad and not FOUND_SPEECHBRAIN: + logger.error( + "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" + ) + sys.exit(1) if architecture not in self.ARCHITECTURES: raise ModelError( f"The architecture {architecture} is not in: {', '.join(self.ARCHITECTURES)}" @@ -1444,8 +1474,113 @@ def __init__(self, architecture: str = "distil-large-v3", **kwargs): super().__init__(**kwargs) self.architecture = architecture self.model = None + self.segmenter = None + if vad: + self.segmenter = SpeechbrainSegmenterMixin(cuda=self.cuda) + self.segmenter.apply_energy_vad = True + self.segmenter.double_check = False + self.processor = None + self.export_directory = export_directory self.transcription_function = WhisperAsrFunction + def setup(self) -> None: + """ + Sets up the corpus and speaker classifier + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + if self.initialized: + return + self.setup_model() + super().setup() + + def setup_model(self, online=False, faster_whisper=False) -> None: + iso_code = self.language.iso_code + if iso_code is None and self.language is not Language.unknown: + raise ModelError( + f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" + ) + try: + if self.cuda: + config.update_configuration( + { + "USE_THREADING": True, + # "USE_MP": False, + } + ) + if faster_whisper: + if self.cuda: + run_opts = {"device": "cuda", "compute_type": "float16"} + self.model = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + cpu_threads=config.NUM_JOBS, + num_workers=config.NUM_JOBS, + **run_opts, + ) + else: + # Download models if needed + m = WhisperModel( + self.architecture, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + local_files_only=False, + ) + if online: + self.model = m + else: + import torch + from transformers import WhisperProcessor, logging + from transformers.utils import is_flash_attn_2_available + + if not config.VERBOSE: + logging.set_verbosity_error() + if config.DEBUG: + logging.set_verbosity_debug() + attn_implementation = ( + "flash_attention_2" if is_flash_attn_2_available() else "sdpa" + ) + logger.debug(f"Using {attn_implementation} for attention") + if self.cuda: + self.model = WhisperForConditionalGeneration.from_pretrained( + f"openai/whisper-{self.architecture}", + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + attn_implementation=attn_implementation, + ) + self.model.to("cuda") + else: + m = WhisperForConditionalGeneration.from_pretrained( + f"openai/whisper-{self.architecture}" + ) + if online: + self.model = m + if 
self.cuda or online: + self.processor = WhisperProcessor.from_pretrained( + f"openai/whisper-{self.architecture}" + ) + suppressed = get_suppressed_tokens(self.processor) + self.model.generation_config.suppress_tokens += suppressed + self.model.generation_config.suppress_tokens = sorted( + set(self.model.generation_config.suppress_tokens) + ) + except Exception: + raise + raise ModelError( + f"Could not download whisper model with {self.architecture} and {self.language.name}" + ) + def transcribe_arguments(self): if self.cuda: return [ @@ -1457,10 +1592,12 @@ def transcribe_arguments(self): f"openai/whisper-{self.architecture}", self.model, self.processor, + self.segmenter, self.language, {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, + self.export_directory, ) ] return [ @@ -1469,11 +1606,12 @@ def transcribe_arguments(self): getattr(self, "session" if config.USE_THREADING else "db_string", ""), self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), self.working_directory, - config.TEMPORARY_DIRECTORY.joinpath("models", "whisper", self.architecture), + f"openai/whisper-{self.architecture}", self.language, {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, + self.export_directory, ) for j in self.jobs ] @@ -1491,6 +1629,7 @@ def faster_whisper_arguments(self): {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, + self.export_directory, ) ] return [ @@ -1504,12 +1643,15 @@ def faster_whisper_arguments(self): {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, + self.export_directory, ) for j in self.jobs ] def transcribe_utterances(self) -> None: super().transcribe_utterances() + if self.segmenter is not None: + return workflow = self.current_workflow iso_code = self.language.iso_code if iso_code is None: @@ -1526,44 +1668,20 @@ def transcribe_utterances(self) -> None: if self.cuda: import torch - run_opts = {"device": "cuda", "compute_type": "float16"} del self.model gc.collect() torch.cuda.empty_cache() - self.model = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, - cpu_threads=config.NUM_JOBS, - num_workers=config.NUM_JOBS, - **run_opts, - ) - config.update_configuration( - { - "USE_THREADING": True, - # "USE_MP": False, - } - ) - else: - # Download models if needed - _ = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, - ) logger.info("Transcribing longer utterances (>30 seconds)...") for u_id, transcript in run_kaldi_function( FasterWhisperFunction, arguments, total_count=num_utterances ): update_mapping.append({"id": u_id, "transcription_text": transcript}) + if self.cuda: + import torch + + del self.model + gc.collect() + torch.cuda.empty_cache() if update_mapping: bulk_update(session, Utterance, update_mapping) session.commit() @@ -1585,84 +1703,13 @@ def transcribe_utterances(self) -> None: log_kaldi_errors(e.error_logs) e.update_log_file() raise - finally: - if self.cuda: - import torch - - del self.model - gc.collect() - torch.cuda.empty_cache() - - # noinspection PyTypeChecker - def setup(self) -> None: - """ - Sets up the corpus and speaker classifier - Raises - ------ - :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` - If there were any errors in running Kaldi binaries - """ - if self.initialized: - return - iso_code = 
self.language.iso_code - if iso_code is None: - raise ModelError( - f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" - ) - try: - from transformers import WhisperProcessor - - model_path = config.TEMPORARY_DIRECTORY.joinpath( - "models", "whisper", self.architecture - ) - if not model_path.exists(): - subprocess.call( - [ - "python", - "-m", - "transformers.models.whisper.convert_openai_to_hf", - "--checkpoint_path", - self.architecture, - "--pytorch_dump_folder_path", - str(model_path), - "--convert_preprocessor", - "True", - ] - ) - if self.cuda: - import torch - from transformers.utils import is_flash_attn_2_available - - self.processor = WhisperProcessor.from_pretrained(model_path) - suppressed = get_suppressed_tokens(self.processor) - attn_implementation = ( - "flash_attention_2" if is_flash_attn_2_available() else "sdpa" - ) - logger.debug(f"Using {attn_implementation} for attention") - - self.model = WhisperForConditionalGeneration.from_pretrained( - model_path, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - attn_implementation=attn_implementation, - ) - self.model.generation_config.suppress_tokens += suppressed - self.model.generation_config.suppress_tokens = list( - set(self.model.generation_config.suppress_tokens) - ) - self.model.generation_config.suppress_tokens.sort() - if self.language.iso_code is not None: - self.model.generation_config.forced_decoder_ids = None - self.model.to("cuda") - else: - _ = WhisperForConditionalGeneration.from_pretrained(model_path) - except Exception: - raise - raise ModelError( - f"Could not download whisper model with {self.architecture} and {self.language.name}" - ) - super().setup() + def cleanup(self) -> None: + if self.model is not None: + del self.model + if self.segmenter is not None: + del self.segmenter + super().cleanup() class SpeechbrainTranscriber(HuggingFaceTranscriber): @@ -1736,6 +1783,9 @@ def setup(self) -> None: savedir=os.path.join( config.TEMPORARY_DIRECTORY, "models", "EncoderASR", model_key ), + huggingface_cache_dir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "hf_cache" + ), ) else: # Download models if needed @@ -1747,6 +1797,9 @@ def setup(self) -> None: "WhisperASR", model_key, ), + huggingface_cache_dir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "hf_cache" + ), ) if self.cuda: self.model = m @@ -1757,3 +1810,8 @@ def setup(self) -> None: f"Could not download a speechbrain model with {self.architecture} and {self.language.name} ({model_key})" ) super().setup() + + def cleanup(self) -> None: + if self.model is not None: + del self.model + super().cleanup() diff --git a/montreal_forced_aligner/vad/multiprocessing.py b/montreal_forced_aligner/vad/multiprocessing.py index 1d6c44d4..cde7ac90 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -41,16 +41,121 @@ torch_logger = logging.getLogger("speechbrain.utils.train_logger") torch_logger.setLevel(logging.ERROR) import torch + import torchaudio try: from speechbrain.pretrained import VAD except ImportError: # speechbrain 1.0 from speechbrain.inference.VAD import VAD + class MfaVAD(VAD): + def energy_VAD( + self, + audio_file: typing.Union[str, Path, np.ndarray], + boundaries, + activation_th=0.5, + deactivation_th=0.0, + eps=1e-6, + ): + """Applies energy-based VAD within the detected speech segments.The neural + network VAD often creates longer segments and tends to merge segments that + are close with each other. 
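+
+            (This MFA override differs from the upstream SpeechBrain ``energy_VAD`` in
+            that ``audio_file`` may also be an in-memory numpy array; in that case each
+            speech segment is sliced directly from the array rather than re-read from
+            disk with torchaudio.)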
+ + The energy VAD post-processes can be useful for having a fine-grained voice + activity detection. + + The energy VAD computes the energy within the small chunks. The energy is + normalized within the segment to have mean 0.5 and +-0.5 of std. + This helps to set the energy threshold. + + Arguments + --------- + audio_file: path + Path of the audio file containing the recording. The file is read + with torchaudio. + boundaries: torch.Tensor + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + activation_th: float + A new speech segment is started it the energy is above activation_th. + deactivation_th: float + The segment is considered ended when the energy is <= deactivation_th. + eps: float + Small constant for numerical stability. + + Returns + ------- + new_boundaries + The new boundaries that are post-processed by the energy VAD. + """ + if not isinstance(audio_file, np.ndarray): + # Getting the total size of the input file + sample_rate, audio_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + else: + sample_rate = self.sample_rate + + # Computing the chunk length of the energy window + chunk_len = int(self.time_resolution * sample_rate) + new_boundaries = [] + + # Processing speech segments + for i in range(boundaries.shape[0]): + begin_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + seg_len = end_sample - begin_sample + + if not isinstance(audio_file, np.ndarray): + # Reading the speech segment + segment, _ = torchaudio.load( + audio_file, frame_offset=begin_sample, num_frames=seg_len + ) + else: + segment = audio_file[begin_sample : begin_sample + seg_len] + + # Create chunks + segment_chunks = self.create_chunks( + segment, chunk_size=chunk_len, chunk_stride=chunk_len + ) + + # Energy computation within each chunk + energy_chunks = segment_chunks.abs().sum(-1) + eps + energy_chunks = energy_chunks.log() + + # Energy normalization + energy_chunks = ( + (energy_chunks - energy_chunks.mean()) / (2 * energy_chunks.std()) + ) + 0.5 + energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2) + + # Apply threshold based on the energy value + energy_vad = self.apply_threshold( + energy_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ) + + # Get the boundaries + energy_boundaries = self.get_boundaries(energy_vad, output_value="seconds") + + # Get the final boundaries in the original signal + for j in range(energy_boundaries.shape[0]): + start_en = boundaries[i, 0] + energy_boundaries[j, 0] + end_end = boundaries[i, 0] + energy_boundaries[j, 1] + new_boundaries.append([start_en, end_end]) + + # Convert boundaries to tensor + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + FOUND_SPEECHBRAIN = True except (ImportError, OSError): FOUND_SPEECHBRAIN = False - VAD = None + MfaVAD = None if TYPE_CHECKING: SpeakerCharacterType = Union[str, int] @@ -95,8 +200,8 @@ class SegmentTranscriptArguments(MfaArguments): def segment_utterance( - utterance: KalpyUtterance, - vad_model: VAD, + segment: Segment, + vad_model: MfaVAD, segmentation_options: MetaDict, mfcc_options: MetaDict = None, vad_options: MetaDict = None, @@ -107,9 +212,9 @@ def segment_utterance( Parameters ---------- - utterance: :class:`~kalpy.utterance.Utterance` - Utterance to split - vad_model: :class:`~speechbrain.pretrained.VAD` 
or None + segment: :class:`~kalpy.data.Segment` + Segment to split + vad_model: :class:`~montreal_forced_aligner.vad.multiprocessing.VAD` or None VAD model from SpeechBrain, if None, then Kaldi's energy-based VAD is used segmentation_options: dict[str, Any] Segmentation options @@ -120,19 +225,17 @@ def segment_utterance( Returns ------- - list[:class:`~kalpy.utterance.Utterance`] - Split utterances + list[:class:`~kalpy.data.Segment`] + Split segments """ if vad_model is None: - segments = segment_utterance_vad( - utterance, mfcc_options, vad_options, segmentation_options - ) + segments = segment_utterance_vad(segment, mfcc_options, vad_options, segmentation_options) else: segments = segment_utterance_vad_speech_brain( - utterance, vad_model, segmentation_options, allow_empty=allow_empty + segment, vad_model, segmentation_options, allow_empty=allow_empty ) if not segments: - return [utterance.segment] + return [segment] return segments @@ -140,7 +243,7 @@ def segment_utterance_transcript( acoustic_model: AcousticModel, utterance: KalpyUtterance, lexicon_compiler: LexiconCompiler, - vad_model: VAD, + vad_model: MfaVAD, segmentation_options: MetaDict, cmvn: DoubleMatrix = None, fmllr_trans: FloatMatrix = None, @@ -223,7 +326,7 @@ def segment_utterance_transcript( cmvn = cmvn_computer.compute_cmvn_from_features([utterance.mfccs]) current_transcript = utterance.transcript segments = segment_utterance( - utterance, vad_model, segmentation_options, mfcc_options, vad_options + utterance.segment, vad_model, segmentation_options, mfcc_options, vad_options ) if not segments: return [utterance] @@ -436,7 +539,7 @@ def merge_segments( def segment_utterance_vad( - utterance: KalpyUtterance, + segment: Segment, mfcc_options: MetaDict, vad_options: MetaDict, segmentation_options: MetaDict, @@ -448,15 +551,10 @@ def segment_utterance_vad( mfcc_options["dither"] = 0.0 mfcc_options["energy_floor"] = 0.0 mfcc_computer = MfccComputer(**mfcc_options) - feats = mfcc_computer.compute_mfccs_for_export(utterance.segment, compress=False) + feats = mfcc_computer.compute_mfccs_for_export(segment, compress=False) if adaptive: vad_options["energy_mean_scale"] = 0.0 mfccs = feats.numpy() - print(mfccs[:, 0]) - min_0, max_0 = mfccs[:, 0].min(), mfccs[:, 0].max() - range = max_0 - min_0 - thresh = (range * 0.6) + min_0 - print("THRESHOLD", thresh, min_0, max_0) vad_options["energy_threshold"] = mfccs[:, 0].mean() vad_computer = VadComputer(**vad_options) vad = vad_computer.compute_vad(feats).numpy() @@ -470,38 +568,37 @@ def segment_utterance_vad( new_segments = [] for s in segments: seg = Segment( - utterance.segment.file_path, - s.begin + utterance.segment.begin, - s.end + utterance.segment.begin, - utterance.segment.channel, + segment.file_path, + s.begin + segment.begin, + s.end + segment.begin, + segment.channel, ) new_segments.append(seg) return new_segments def segment_utterance_vad_speech_brain( - utterance: KalpyUtterance, - vad_model: VAD, + segment: Segment, + vad_model: MfaVAD, segmentation_options: MetaDict, allow_empty: bool = True, ) -> typing.List[Segment]: - y = utterance.segment.wave - prob_chunks = vad_model.get_speech_prob_chunk( - torch.tensor(y[np.newaxis, :], device=vad_model.device) - ).cpu() + y = segment.wave + prob_chunks = vad_model.get_speech_prob_chunk(torch.tensor(y[np.newaxis, :])).float() prob_th = vad_model.apply_threshold( prob_chunks, activation_th=segmentation_options["activation_th"], deactivation_th=segmentation_options["deactivation_th"], ).float() - # Compute the boundaries of the 
speech segments - boundaries = vad_model.get_boundaries(prob_th, output_value="seconds") - boundaries += utterance.segment.begin + # Compute the boundaries of the speech segments + boundaries = vad_model.get_boundaries(prob_th, output_value="seconds").cpu() + if segment.begin is not None: + boundaries += segment.begin # Apply energy-based VAD on the detected speech segments if segmentation_options["apply_energy_VAD"]: vad_boundaries = vad_model.energy_VAD( - utterance.segment.file_path, + segment.file_path, boundaries, activation_th=segmentation_options["en_activation_th"], deactivation_th=segmentation_options["en_deactivation_th"], @@ -524,22 +621,20 @@ def segment_utterance_vad_speech_brain( # Double check speech segments if segmentation_options["double_check"]: checked_boundaries = vad_model.double_check_speech_segments( - boundaries, utterance.segment.file_path, speech_th=segmentation_options["speech_th"] + boundaries, segment.file_path, speech_th=segmentation_options["speech_th"] ) if checked_boundaries.size(0) != 0 or allow_empty: boundaries = checked_boundaries - print(boundaries) boundaries[:, 0] -= round(segmentation_options["close_th"] / 2, 3) boundaries[:, 1] += round(segmentation_options["close_th"] / 2, 3) - boundaries = boundaries.numpy() segments = [] - for i in range(boundaries.shape[0]): + for i in range(boundaries.numpy().shape[0]): begin, end = boundaries[i] - begin = max(begin, 0) - end = min(end, utterance.segment.end) - seg = Segment( - utterance.segment.file_path, float(begin), float(end), utterance.segment.channel - ) + if i == 0: + begin = max(begin, 0) + if i == boundaries.numpy().shape[0] - 1: + end = min(end, segment.end) + seg = Segment(segment.file_path, float(begin), float(end), segment.channel) segments.append(seg) return segments diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index b109f7d5..be960e9b 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -39,7 +39,7 @@ from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function from montreal_forced_aligner.vad.multiprocessing import ( FOUND_SPEECHBRAIN, - VAD, + MfaVAD, SegmentTranscriptArguments, SegmentTranscriptFunction, SegmentVadArguments, @@ -58,7 +58,7 @@ class SpeechbrainSegmenterMixin: def __init__( self, - segment_padding: float = 0.01, + segment_padding: float = 0.1, large_chunk_size: float = 30, small_chunk_size: float = 0.05, overlap_small_chunk: bool = False, @@ -72,10 +72,9 @@ def __init__( en_deactivation_th: float = 0.4, speech_th: float = 0.5, cuda: bool = False, - speechbrain: bool = False, **kwargs, ): - if speechbrain and not FOUND_SPEECHBRAIN: + if not FOUND_SPEECHBRAIN: logger.error( "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" ) @@ -94,18 +93,17 @@ def __init__( self.en_deactivation_th = en_deactivation_th self.speech_th = speech_th self.cuda = cuda - self.speechbrain = speechbrain + self.speechbrain = True self.segment_padding = segment_padding self.vad_model = None - if self.speechbrain: - model_dir = os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD") - os.makedirs(model_dir, exist_ok=True) - run_opts = None - if self.cuda: - run_opts = {"device": "cuda"} - self.vad_model = VAD.from_hparams( - source="speechbrain/vad-crdnn-libriparty", savedir=model_dir, run_opts=run_opts - ) + model_dir = os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD") + os.makedirs(model_dir, exist_ok=True) + run_opts = None + if 
self.cuda: + run_opts = {"device": "cuda"} + self.vad_model = MfaVAD.from_hparams( + source="speechbrain/vad-crdnn-libriparty", savedir=model_dir, run_opts=run_opts + ) @property def segmentation_options(self) -> MetaDict: @@ -290,10 +288,12 @@ def segment_vad_speechbrain(self) -> None: for i in range(boundaries.shape[0]): old_utts.add(u.id) begin, end = boundaries[i, :] - begin -= self.segment_padding - end += self.segment_padding - begin = max(0.0, begin) - end = min(f.sound_file.duration, end) + begin -= round(kwargs["close_th"] / 2, 3) + end += round(kwargs["close_th"] / 2, 3) + if i == 0: + begin = max(0.0, begin) + if i == boundaries.shape[0] - 1: + end = min(f.sound_file.duration, end) new_utts.append( { "id": utt_index, @@ -450,7 +450,7 @@ def segment_utterance(self, utterance_id: int, allow_empty: bool = True): utterance = full_load_utterance(session, utterance_id) new_utterances = segment_utterance( - utterance.to_kalpy(), + utterance.to_kalpy().segment, self.vad_model if self.speechbrain else None, self.segmentation_options, mfcc_options=self.mfcc_options if not self.speechbrain else None, @@ -654,7 +654,6 @@ def segment_transcript(self, utterance_id: int): if not pronunciations: pronunciations = [self.oov_phone] lexicon_compiler.word_table.add_symbol(w) - print(w, ps, pronunciations) for p in pronunciations: lexicon_compiler.pronunciations.append( KalpyPronunciation( diff --git a/pyproject.toml b/pyproject.toml index d339d8bc..47ae3f9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,11 @@ write_to = "montreal_forced_aligner/_version.py" [tool.black] line-length = 99 +[tool.flake8] +max-line-length = 99 +extend-ignore = ["D203", "E203", "E251", "E266", "E302", "E305", "E401", "E402", "E501", "F401", "F403", "W503"] +exclude = [".git", "__pycache__", "dist", "build"] + [tool.isort] line_length = 99 profile = "black" diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 66e25191..88cc7b31 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -4,7 +4,11 @@ import pytest from montreal_forced_aligner.command_line.mfa import mfa_cli -from montreal_forced_aligner.transcription.multiprocessing import FOUND_SPEECHBRAIN, FOUND_WHISPER +from montreal_forced_aligner.transcription.multiprocessing import ( + FOUND_FASTER_WHISPER, + FOUND_SPEECHBRAIN, + FOUND_TRANSFORMERS, +) def test_transcribe( @@ -93,8 +97,10 @@ def test_transcribe_whisper( temp_dir, db_setup, ): - if not FOUND_WHISPER: + if not FOUND_TRANSFORMERS: pytest.skip("transformers not installed") + if not FOUND_FASTER_WHISPER: + pytest.skip("faster-whisper not installed") output_path = generated_dir.joinpath("transcribe_test_whisper") command = [ "transcribe_whisper", From 70faa9775a14436f197f7a653096a9af8932df7d Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Fri, 6 Sep 2024 18:36:05 -0700 Subject: [PATCH 04/16] Somewhat optimized whisper pipeline --- environment.yml | 8 +- .../command_line/transcribe.py | 25 +- .../dictionary/multispeaker.py | 2 +- .../online/transcription.py | 104 +--- .../transcription/models.py | 227 ++++++++ .../transcription/multiprocessing.py | 519 ++++++------------ .../transcription/transcriber.py | 230 ++------ montreal_forced_aligner/utils.py | 16 + montreal_forced_aligner/vad/models.py | 403 ++++++++++++++ .../vad/multiprocessing.py | 198 +------ montreal_forced_aligner/vad/segmenter.py | 92 +--- tests/test_commandline_transcribe.py | 13 +- 12 files changed, 884 insertions(+), 953 deletions(-) 
create mode 100644 montreal_forced_aligner/transcription/models.py create mode 100644 montreal_forced_aligner/vad/models.py diff --git a/environment.yml b/environment.yml index a4b2efd2..826d4ef4 100644 --- a/environment.yml +++ b/environment.yml @@ -47,11 +47,13 @@ dependencies: - rich - rich-click - kalpy + # Tokenization dependencies - spacy - sudachipy - sudachidict-core - spacy-pkuseg - # Whisper dependencies + # WhisperX dependencies + - cudnn =8 - transformers - tokenizers - accelerate @@ -59,7 +61,7 @@ dependencies: - pip: - build - twine - - faster-whisper + # Tokenization dependencies - python-mecab-ko - jamo - pythainlp @@ -69,3 +71,5 @@ dependencies: - speechbrain - kenlm - pygtrie + # WhisperX dependencies + - whisperx diff --git a/montreal_forced_aligner/command_line/transcribe.py b/montreal_forced_aligner/command_line/transcribe.py index ddfc2d46..bb3e48bb 100644 --- a/montreal_forced_aligner/command_line/transcribe.py +++ b/montreal_forced_aligner/command_line/transcribe.py @@ -15,10 +15,7 @@ validate_language_model, ) from montreal_forced_aligner.data import Language -from montreal_forced_aligner.online.transcription import ( - transcribe_utterance_online_faster_whisper, - transcribe_utterance_online_whisper, -) +from montreal_forced_aligner.online.transcription import transcribe_utterance_online_whisper from montreal_forced_aligner.transcription.transcriber import ( SpeechbrainTranscriber, Transcriber, @@ -346,22 +343,12 @@ def transcribe_whisper_cli(context, **kwargs) -> None: try: if not input_path.is_dir(): segment = Segment(input_path) - faster_whisper = segment.wave.shape[0] / 16_000 > 30 - faster_whisper = False - transcriber.setup_model(online=True, faster_whisper=faster_whisper) + transcriber.setup_model(online=True) - if faster_whisper: - text = transcribe_utterance_online_faster_whisper( - transcriber.model, segment, language=transcriber.language - ) - else: - text = transcribe_utterance_online_whisper( - transcriber.model, - transcriber.processor, - segment, - language=transcriber.language, - segmenter=transcriber.segmenter, - ) + text = transcribe_utterance_online_whisper( + transcriber.model, + segment, + ) if str(output_path) == "-": print(text) # noqa sys.exit(0) diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index dd9fff2f..e6a7e37b 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -867,7 +867,7 @@ def calculate_disambiguation(self) -> None: .filter(Word.dictionary_id == d.id) .filter(Word.included == True) # noqa .options(selectinload(Word.pronunciations)) - ) + ).all() for w in words: for p in w.pronunciations: pron = p.pronunciation.split() diff --git a/montreal_forced_aligner/online/transcription.py b/montreal_forced_aligner/online/transcription.py index 208e8b97..cf07b198 100644 --- a/montreal_forced_aligner/online/transcription.py +++ b/montreal_forced_aligner/online/transcription.py @@ -14,23 +14,16 @@ from kalpy.gmm.decode import GmmDecoder from kalpy.utterance import Utterance as KalpyUtterance -from montreal_forced_aligner.data import Language +from montreal_forced_aligner import config from montreal_forced_aligner.exceptions import AlignerError from montreal_forced_aligner.models import AcousticModel from montreal_forced_aligner.tokenization.simple import SimpleTokenizer +from montreal_forced_aligner.transcription.models import FOUND_WHISPERX, MfaFasterWhisperPipeline from 
montreal_forced_aligner.transcription.multiprocessing import ( - FOUND_FASTER_WHISPER, FOUND_SPEECHBRAIN, - FOUND_TRANSFORMERS, EncoderASR, WhisperASR, - WhisperForConditionalGeneration, - WhisperModel, - WhisperProcessor, - get_suppressed_tokens, ) -from montreal_forced_aligner.vad.multiprocessing import segment_utterance_vad_speech_brain -from montreal_forced_aligner.vad.segmenter import SpeechbrainSegmenterMixin def transcribe_utterance_online( @@ -104,97 +97,24 @@ def transcribe_utterance_online( return ctm -def transcribe_utterance_online_faster_whisper( - model: WhisperModel, - segment: Segment, - beam: int = 5, - language: Language = Language.unknown, - tokenizer: SimpleTokenizer = None, -) -> str: - if not FOUND_FASTER_WHISPER: - raise Exception( - "Could not import faster-whisper, please ensure it is installed via `pip install faster-whisper`" - ) - waveform = segment.wave - suppressed = get_suppressed_tokens(model) - segments, info = model.transcribe( - waveform, - language=language.iso_code, - beam_size=beam, - suppress_tokens=suppressed, - temperature=1.0, - condition_on_previous_text=False, - ) - texts = [] - for x in segments: - if x.no_speech_prob > 0.6: - continue - texts.append(x.text) - text = " ".join(texts) - text = text.replace(" ", " ") - if tokenizer is not None: - text = tokenizer(text)[0] - return text.strip() - - def transcribe_utterance_online_whisper( - model: WhisperForConditionalGeneration, - processor: WhisperProcessor, + model: MfaFasterWhisperPipeline, segment: Segment, - beam_size: int = 5, - language: Language = Language.unknown, tokenizer: SimpleTokenizer = None, - segmenter: SpeechbrainSegmenterMixin = None, ) -> str: - if not FOUND_TRANSFORMERS: + if not FOUND_WHISPERX: raise Exception( "Could not import transformers, please ensure it is installed via `conda install transformers`" ) - raw_audio = [] - if segmenter is None: - audio = segment.wave.astype(np.float32) - raw_audio.append(audio) - else: - segments = segment_utterance_vad_speech_brain( - segment, segmenter.vad_model, segmenter.segmentation_options, allow_empty=True - ) - if len(segments) < 2: - raw_audio.append(segment.wave.astype(np.float32)) - else: - for s in segments: - raw_audio.append(s.wave.astype(np.float32)) - inputs = processor( - raw_audio, - return_tensors="pt", - truncation=True, - return_attention_mask=True, - sampling_rate=16_000, - pad_to_multiple_of=128, - device=model.device.type, - ) - inputs = inputs.to(model.device, model.dtype) - if language is not Language.unknown: - language = language.iso_code - else: - language = None - result = model.generate( - **inputs, - condition_on_prev_tokens=False, - temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if segmenter is None else 0.0, - logprob_threshold=-1.0, - compression_ratio_threshold=1.35, - return_timestamps=False, - language=language, + audio = segment.wave.astype(np.float32) + vad_segments = model.vad_model.segment_for_whisper(audio, **model._vad_params) + result = model.transcribe( + vad_segments, [0 for _ in range(len(vad_segments))], batch_size=config.NUM_JOBS ) - decoded = [] - special_ids = processor.tokenizer.all_special_ids - for r in result: - r = [t for t in r if t not in special_ids] - tokens = processor.tokenizer.convert_tokens_to_string( - processor.tokenizer.convert_ids_to_tokens(r) - ).strip() - decoded.append(tokens) - text = " ".join(decoded) + texts = [] + for seg in result[0]: + texts.append(seg["text"].strip()) + text = " ".join(texts) if tokenizer is not None: text = tokenizer(text)[0] return text.strip() 
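A minimal usage sketch of the single-file online path added above, mirroring the non-directory branch of `transcribe_whisper_cli`: the wav path is hypothetical, the `architecture` value shown is just the default from this patch, and the remaining WhisperTranscriber options are assumed to fall back to the current profile defaults.

    from pathlib import Path

    from kalpy.data import Segment

    from montreal_forced_aligner.online.transcription import (
        transcribe_utterance_online_whisper,
    )
    from montreal_forced_aligner.transcription.transcriber import WhisperTranscriber

    wav_path = Path("/path/to/utterance.wav")  # hypothetical input file
    transcriber = WhisperTranscriber(
        corpus_directory=wav_path.parent,  # single-file mode still needs a corpus root
        architecture="distil-large-v3",
    )
    transcriber.setup_model(online=True)  # loads only the MfaFasterWhisperPipeline
    text = transcribe_utterance_online_whisper(transcriber.model, Segment(wav_path))
    print(text)
    del transcriber.model  # release the model, as the CLI does before exiting

In the CLI, this is the branch that lets `mfa transcribe_whisper utterance.wav -` print the transcript straight to stdout without building out the full corpus pipeline.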
diff --git a/montreal_forced_aligner/transcription/models.py b/montreal_forced_aligner/transcription/models.py new file mode 100644 index 00000000..471c63ec --- /dev/null +++ b/montreal_forced_aligner/transcription/models.py @@ -0,0 +1,227 @@ +"""Model classes for Transcription""" +from __future__ import annotations + +import re +import typing +import warnings + +import numpy as np + +try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import faster_whisper + from whisperx import asr + from whisperx.asr import FasterWhisperPipeline + + FOUND_WHISPERX = True + +except ImportError: + FasterWhisperPipeline = object + FOUND_WHISPERX = False + +if typing.TYPE_CHECKING: + import torch + + +class MfaFasterWhisperPipeline(FasterWhisperPipeline): + def __init__( + self, + model, + vad, + vad_params: dict, + options: typing.NamedTuple, + tokenizer=None, + device: typing.Union[int, str, torch.device] = -1, + framework: str = "pt", + language: typing.Optional[str] = None, + suppress_numerals: bool = False, + **kwargs, + ): + super().__init__( + model, + vad, + vad_params, + options, + tokenizer, + device, + framework, + language, + suppress_numerals, + **kwargs, + ) + self.base_suppress_tokens = self.options.suppress_tokens + if self.preset_language is not None: + self.load_tokenizer(task="transcribe", language=self.preset_language) + + def get_suppressed_tokens( + self, + ) -> typing.List[int]: + suppressed = [] + import unicodedata + + alpha_pattern = re.compile(r"\w", flags=re.UNICODE) + roman_numeral_pattern = re.compile(r"^(x+(vi+|i+|i?v|x+))$", flags=re.IGNORECASE) + case_roman_numeral_pattern = re.compile(r"(^[IXV]{2,}$|^[xvi]+i$|^x{2,}$|\d)") + + def _should_suppress(t): + if t.startswith("<|"): + return False + if any(unicodedata.category(c) in {"Mn", "Mc"} for c in t): + return False + if ( + roman_numeral_pattern.search(t) + or case_roman_numeral_pattern.search(t) + or re.match(r"^[XV]$", t) + or not alpha_pattern.search(t) + ): + return True + return False + + for token_id in range(self.tokenizer.eot): + token = self.tokenizer.decode([token_id]).strip() + if not token: + continue + if _should_suppress(token): + suppressed.append(token_id) + return suppressed + + def load_tokenizer(self, task, language): + self.tokenizer = faster_whisper.tokenizer.Tokenizer( + self.model.hf_tokenizer, self.model.model.is_multilingual, task=task, language=language + ) + if self.suppress_numerals: + numeral_symbol_tokens = self.get_suppressed_tokens() + new_suppressed_tokens = numeral_symbol_tokens + self.base_suppress_tokens + new_suppressed_tokens = sorted(set(new_suppressed_tokens)) + self.options = self.options._replace(suppress_tokens=new_suppressed_tokens) + + def transcribe( + self, + audio_batch: typing.List[typing.Dict[str, typing.Union[np.ndarray, float]]], + utterance_ids, + batch_size=None, + num_workers=0, + vad_segments=None, + ): + utterances = {} + batch_size = batch_size or self._batch_size + for idx, out in enumerate( + self.__call__(audio_batch, batch_size=batch_size, num_workers=num_workers) + ): + text = out["text"] + utterance_id = utterance_ids[idx] + if utterance_id not in utterances: + utterances[utterance_id] = [] + if batch_size in [0, 1, None]: + text = text[0] + utterances[utterance_id].append( + { + "text": text, + "start": round(audio_batch[idx]["start"], 3), + "end": round(audio_batch[idx]["end"], 3), + } + ) + + return utterances + + +def load_model( + whisper_arch, + device, + device_index=0, + compute_type="float16", + asr_options=None, + language: 
typing.Optional[str] = None, + vad_model=None, + vad_options=None, + model: typing.Optional[asr.WhisperModel] = None, + download_root=None, + threads=4, +): + """Load a Whisper model for inference. + Args: + whisper_arch: str - The name of the Whisper model to load. + device: str - The device to load the model on. + compute_type: str - The compute type to use for the model. + options: dict - A dictionary of options to use for the model. + language: str - The language of the model. (use English for now) + vad_model_fp: str - File path to the VAD model to use + model: Optional[WhisperModel] - The WhisperModel instance to use. + download_root: Optional[str] - The root directory to download the model to. + threads: int - The number of cpu threads to use per worker, e.g. will be multiplied by num workers. + Returns: + A Whisper pipeline. + """ + + if whisper_arch.endswith(".en"): + language = "en" + + model = model or asr.WhisperModel( + whisper_arch, + device=device, + device_index=device_index, + compute_type=compute_type, + download_root=download_root, + cpu_threads=threads, + ) + + default_asr_options = { + "beam_size": 5, + "best_of": 5, + "patience": 1, + "length_penalty": 1, + "repetition_penalty": 1, + "no_repeat_ngram_size": 0, + "temperatures": [0.0, 0.2, 0.4, 0.6, 0.8, 1.0], + "compression_ratio_threshold": 2.4, + "log_prob_threshold": -1.0, + "no_speech_threshold": 0.6, + "condition_on_previous_text": False, + "prompt_reset_on_temperature": 0.5, + "initial_prompt": None, + "prefix": None, + "suppress_blank": True, + "suppress_tokens": [-1], + "without_timestamps": True, + "max_initial_timestamp": 0.0, + "word_timestamps": False, + "prepend_punctuations": "\"'“¿([{-", + "append_punctuations": "\"'.。,,!!??::”)]}、", + "suppress_numerals": True, + "max_new_tokens": None, + "clip_timestamps": None, + "hallucination_silence_threshold": None, + } + + if asr_options is not None: + default_asr_options.update(asr_options) + + suppress_numerals = default_asr_options["suppress_numerals"] + del default_asr_options["suppress_numerals"] + + default_asr_options = faster_whisper.transcribe.TranscriptionOptions(**default_asr_options) + + default_vad_options = { + "apply_energy_VAD": False, + "double_check": False, + "activation_th": 0.5, + "deactivation_th": 0.25, + "en_activation_th": 0.5, + "en_deactivation_th": 0.4, + "speech_th": 0.5, + "close_th": 0.333, + "len_th": 0.333, + } + + if vad_options is not None: + default_vad_options.update(vad_options) + + return MfaFasterWhisperPipeline( + model=model, + vad=vad_model, + options=default_asr_options, + language=language, + suppress_numerals=suppress_numerals, + vad_params=default_vad_options, + ) diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index 24def2f7..adc169da 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -8,9 +8,9 @@ import logging import os import queue -import re import threading import typing +import warnings from pathlib import Path from typing import TYPE_CHECKING, Dict @@ -34,39 +34,19 @@ from montreal_forced_aligner import config from montreal_forced_aligner.abc import KaldiFunction, MetaDict from montreal_forced_aligner.data import Language, MfaArguments, PhoneType -from montreal_forced_aligner.db import File, Job, Phone, SoundFile, Utterance +from montreal_forced_aligner.db import File, Job, Phone, SoundFile, Speaker, Utterance from 
montreal_forced_aligner.diarization.multiprocessing import UtteranceFileLoader from montreal_forced_aligner.tokenization.simple import SimpleTokenizer +from montreal_forced_aligner.transcription.models import MfaFasterWhisperPipeline, load_model from montreal_forced_aligner.utils import mfa_open, thread_logger -from montreal_forced_aligner.vad.multiprocessing import segment_utterance_vad_speech_brain +from montreal_forced_aligner.vad.models import MfaVAD if TYPE_CHECKING: from dataclasses import dataclass - - from montreal_forced_aligner.vad.segmenter import SpeechbrainSegmenterMixin else: from dataclassy import dataclass try: - from faster_whisper import WhisperModel - - FOUND_FASTER_WHISPER = True -except ImportError: - WhisperModel = None - FOUND_FASTER_WHISPER = False - -try: - from transformers import WhisperForConditionalGeneration, WhisperProcessor - - FOUND_TRANSFORMERS = True -except ImportError: - WhisperForConditionalGeneration = None - WhisperProcessor = None - FOUND_TRANSFORMERS = False - -try: - import warnings - with warnings.catch_warnings(): warnings.simplefilter("ignore") torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend") @@ -103,9 +83,6 @@ "LmRescoreFunction", "CreateHclgFunction", "FOUND_SPEECHBRAIN", - "FOUND_FASTER_WHISPER", - "FOUND_TRANSFORMERS", - "WhisperForConditionalGeneration", "WhisperASR", "EncoderASR", "SpeechbrainAsrArguments", @@ -116,6 +93,8 @@ "WhisperAsrFunction", ] +logger = logging.getLogger("mfa") + @dataclass class CreateHclgArguments(MfaArguments): @@ -217,9 +196,8 @@ class WhisperArguments(MfaArguments): """ working_directory: Path - model_id: str + architecture: str language: Language - decode_options: MetaDict tokenizer: typing.Optional[SimpleTokenizer] cuda: bool export_directory: typing.Optional[Path] @@ -243,64 +221,7 @@ class WhisperCudaArguments(MfaArguments): """ working_directory: Path - model_id: str - model: WhisperForConditionalGeneration - processor: WhisperProcessor - segmenter: SpeechbrainSegmenterMixin - language: Language - decode_options: MetaDict - tokenizer: typing.Optional[SimpleTokenizer] - cuda: bool - export_directory: typing.Optional[Path] - - -@dataclass -class FasterWhisperArguments(MfaArguments): - """ - Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` - - Parameters - ---------- - job_name: int - Integer ID of the job - session: :class:`sqlalchemy.orm.scoped_session` or str - SqlAlchemy scoped session or string for database connections - log_path: :class:`~pathlib.Path` - Path to save logging information during the run - working_directory: :class:`~pathlib.Path` - Current working directory - """ - - working_directory: Path - model_size: str - language: Language - decode_options: MetaDict - tokenizer: typing.Optional[SimpleTokenizer] - cuda: bool - export_directory: typing.Optional[Path] - - -@dataclass -class FasterWhisperCudaArguments(MfaArguments): - """ - Arguments for :class:`~montreal_forced_aligner.transcription.multiprocessing.CreateHclgFunction` - - Parameters - ---------- - job_name: int - Integer ID of the job - session: :class:`sqlalchemy.orm.scoped_session` or str - SqlAlchemy scoped session or string for database connections - log_path: :class:`~pathlib.Path` - Path to save logging information during the run - working_directory: :class:`~pathlib.Path` - Current working directory - """ - - working_directory: Path - model: WhisperModel - language: Language - decode_options: MetaDict + model: MfaFasterWhisperPipeline tokenizer: 
typing.Optional[SimpleTokenizer] cuda: bool export_directory: typing.Optional[Path] @@ -808,10 +729,8 @@ def __init__( return_q: queue.Queue, stopped: threading.Event, finished_adding: threading.Event, - processor: WhisperProcessor, - segmenter: SpeechbrainSegmenterMixin = None, + model: MfaFasterWhisperPipeline, export_directory: Path = None, - device: str = "cpu", ): super().__init__() self.job_name = job_name @@ -819,18 +738,14 @@ def __init__( self.return_q = return_q self.stopped = stopped self.finished_adding = finished_adding - self.processor = processor - self.segmenter = segmenter + self.model = model self.export_directory = export_directory - self.device = device def run(self) -> None: """ Run the waveform loading job """ - batch_size = config.NUM_JOBS - with self.session() as session: try: utterances = ( @@ -842,225 +757,100 @@ def run(self) -> None: Utterance.channel, File.relative_path, File.name, + Speaker.name, ) + .join(Utterance.speaker) .join(Utterance.file) .join(File.sound_file) ) - if self.segmenter is None: - utterances = utterances.filter(Utterance.duration <= 30) - utterances = utterances.order_by(Utterance.duration.desc()) - else: - utterances = utterances.order_by(Utterance.speaker_id) + utterances = utterances.order_by(Utterance.speaker_id) if not utterances.count(): self.finished_adding.set() return - raw_audio = [] - utterance_ids = [] - export_paths = [] for u in utterances: if self.stopped.is_set(): break segment = Segment(u[1], u[2], u[3], u[4]) export_path = None if self.export_directory is not None: - export_path = self.export_directory.joinpath(u[5], u[6] + ".lab") - if export_path.exists(): + export_path = self.export_directory.joinpath(u[5], u[6]) + if any(export_path.with_suffix(x).exists() for x in [".lab", ".TextGrid"]): continue - utterance_ids.append(u[0]) - if self.segmenter is None: - audio = segment.load_audio().astype(np.float32) - raw_audio.append(audio) - export_paths.append(export_path) - if len(utterance_ids) >= batch_size: - inputs = self.processor( - raw_audio, - return_tensors="pt", - truncation=True, - return_attention_mask=True, - sampling_rate=16_000, - device=self.device, - ) - self.return_q.put((utterance_ids, inputs)) - raw_audio = [] - utterance_ids = [] - export_paths = [] - else: - segments = segment_utterance_vad_speech_brain( - segment, - self.segmenter.vad_model, - self.segmenter.segmentation_options, - allow_empty=True, - ) - if not segments: - continue - if len(segments) == 1: - raw_audio.append(segment.wave.astype(np.float32)) - else: - for s in segments: - raw_audio.append(s.wave.astype(np.float32)) - inputs = self.processor( - raw_audio, - return_tensors="pt", - truncation=True, - return_attention_mask=True, - sampling_rate=16_000, - device=self.device, - ) - self.return_q.put((u[0], inputs, export_path)) - raw_audio = [] - utterance_ids = [] - export_paths = [] - - if utterance_ids: - inputs = self.processor( - raw_audio, - return_tensors="pt", - truncation=True, - return_attention_mask=True, - sampling_rate=16_000, - device=self.device, + audio = segment.load_audio().astype(np.float32) + segments = self.model.vad_model.segment_for_whisper( + audio, **self.model._vad_params ) - self.return_q.put((utterance_ids, inputs, export_paths)) + self.return_q.put((u[0], segments, export_path, u[7], u[2], u[3])) except Exception as e: self.return_q.put(e) finally: self.finished_adding.set() -def get_suppressed_tokens( - whisper_processor: typing.Union[WhisperProcessor, WhisperModel] -) -> typing.List[int]: - suppressed = [] - 
import unicodedata - - alpha_pattern = re.compile(r"\w", flags=re.UNICODE) - roman_numeral_pattern = re.compile(r"^(x+(vi+|i+|i?v|x+))$", flags=re.IGNORECASE) - case_roman_numeral_pattern = re.compile(r"(^[IXV]{2,}$|^[xvi]+i$|^x{2,}$|\d)") - - def _should_suppress(t): - if t.startswith("<|"): - return False - if any(unicodedata.category(c) in {"Mn", "Mc"} for c in t): - return False - if ( - roman_numeral_pattern.search(t) - or case_roman_numeral_pattern.search(t) - or re.match(r"^[XV]$", t) - or not alpha_pattern.search(t) - ): - return True - return False - - if isinstance(whisper_processor, WhisperProcessor): - for token_id in range(whisper_processor.tokenizer.vocab_size): - token = whisper_processor.tokenizer.convert_tokens_to_string( - whisper_processor.tokenizer.convert_ids_to_tokens([token_id]) - ).strip() - if not token: - continue - if _should_suppress(token): - suppressed.append(token_id) - else: - for token_id in range(whisper_processor.hf_tokenizer.eot): - token = whisper_processor.hf_tokenizer.decode([token_id]).strip() - if not token: - continue - if _should_suppress(token): - suppressed.append(token_id) - return suppressed - - -class FasterWhisperFunction(KaldiFunction): +class WhisperUtteranceVAD(threading.Thread): """ - Multiprocessing function for performing decoding - - See Also - -------- - :meth:`.TranscriberMixin.transcribe_utterances` - Main function that calls this function in parallel - :meth:`.TranscriberMixin.decode_arguments` - Job method for generating arguments for this function - :kaldi_src:`gmm-latgen-faster` - Relevant Kaldi binary + Helper process for loading utterance waveforms in parallel with embedding extraction Parameters ---------- - args: :class:`~montreal_forced_aligner.transcription.multiprocessing.DecodeArguments` - Arguments for the function + job_name: int + Job identifier + session: sqlalchemy.orm.scoped_session + Session + return_q: :class:`~queue.Queue` + Queue to put waveforms + stopped: :class:`~threading.Event` + Check for whether the process to exit gracefully + finished_adding: :class:`~threading.Event` + Check for whether the worker has processed all utterances """ - def __init__(self, args: typing.Union[FasterWhisperArguments, FasterWhisperCudaArguments]): - super().__init__(args) - self.working_directory = args.working_directory - self.cuda = args.cuda - self.model = None - self.language = args.language - self.decode_options = args.decode_options - if isinstance(args, FasterWhisperCudaArguments): - self.model = args.model - else: - self.model = args.model_size - self.tokenizer = args.tokenizer + def __init__( + self, + job_name: int, + job_q: queue.Queue, + return_q: queue.Queue, + stopped: threading.Event, + finished_adding: threading.Event, + model: MfaFasterWhisperPipeline, + export_directory: Path = None, + ): + super().__init__() + self.job_name = job_name + self.job_q = job_q + self.return_q = return_q + self.stopped = stopped + self.finished_adding = finished_adding + self.model = model + self.export_directory = export_directory - def _run(self) -> None: - """Run the function""" - model = self.model - if isinstance(model, str): - if self.cuda: - run_opts = {"device": "cuda", "compute_type": "float16"} - else: - run_opts = {"device": "cpu"} - model = WhisperModel( - model, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=True, - **run_opts, - ) - transcribe_opts = {"language": None, "beam_size": self.decode_options["beam_size"]} - if self.language is not 
Language.unknown: - transcribe_opts["language"] = self.language.iso_code - suppressed = get_suppressed_tokens(model) - current_index = 0 - with self.session() as session, mfa_open(self.log_path, "w") as log_file: - log_file.write(f"Suppressed: {len(suppressed)}\n") - utterances = ( - session.query( - Utterance.id, - SoundFile.sound_file_path, - Utterance.begin, - Utterance.end, - Utterance.channel, - ) - .join(Utterance.file) - .join(File.sound_file) - .filter(Utterance.job_id == self.job_name) - .filter(Utterance.duration > 30) - ) - for u in utterances: - segment = Segment(u[1], u[2], u[3], u[4]) - waveform = segment.load_audio() - log_file.write(f"{u[0]}: {waveform.shape}\n") - segments, info = model.transcribe( - waveform, - condition_on_previous_text=False, - suppress_tokens=suppressed, - temperature=0.0, - **transcribe_opts, + def run(self) -> None: + """ + Run the waveform loading job + """ + + while True: + try: + batch = self.job_q.get(timeout=1) + except queue.Empty: + if self.finished_adding.is_set(): + break + continue + if self.stopped.is_set(): + continue + if isinstance(batch, Exception): + exception = batch + self.return_q.put(exception) + self.stopped.set() + continue + try: + utterance_id, audio, export_path, speaker_name, begin, end = batch + segments = self.model.vad_model.segment_for_whisper( + audio, **self.model._vad_params ) - text = " ".join([x.text for x in segments]) - if self.tokenizer is not None: - text = self.tokenizer(text)[0] - self.callback((u[0], text)) - log_file.write(f"{u[0]}: {text}\n") - log_file.flush() - current_index += 1 - if current_index > 50: - torch.cuda.empty_cache() - current_index = 0 + self.return_q.put((utterance_id, segments, export_path, speaker_name, begin, end)) + except Exception as e: + self.return_q.put(e) class WhisperAsrFunction(KaldiFunction): @@ -1087,43 +877,48 @@ def __init__(self, args: typing.Union[WhisperArguments, WhisperCudaArguments]): self.working_directory = args.working_directory self.working_directory = args.working_directory self.cuda = args.cuda - self.model_id = args.model_id + self.architecture = None self.model = None - self.processor = None - self.segmenter = None - self.language = args.language - self.decode_options = args.decode_options + self.language = None self.export_directory = args.export_directory if isinstance(args, WhisperCudaArguments): self.model = args.model - self.processor = args.processor - self.segmenter = args.segmenter + else: + self.language = args.language + self.architecture = args.architecture self.tokenizer = args.tokenizer def _run(self) -> None: """Run the function""" - processor = self.processor - if processor is None: - processor = WhisperProcessor.from_pretrained(self.model_id) - processor.tokenizer.add_prefix_space = False - language = None - if self.language is not Language.unknown: - language = self.language.iso_code model = self.model if model is None: - suppressed = get_suppressed_tokens(processor) - model = WhisperForConditionalGeneration.from_pretrained(self.model_id) - model.generation_config.suppress_tokens += suppressed - model.generation_config.suppress_tokens = list( - set(model.generation_config.suppress_tokens) + language = None + if self.language is not Language.unknown: + language = self.language.iso_code + run_opts = None + if self.cuda: + run_opts = {"device": "cuda"} + vad_model = MfaVAD.from_hparams( + source="speechbrain/vad-crdnn-libriparty", + savedir=os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD"), + run_opts=run_opts, + ) + model = load_model( + 
self.architecture, + device="cuda" if self.cuda else "cpu", + language=language, + vad_model=vad_model, + vad_options=None, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + threads=config.NUM_JOBS, ) - model.generation_config.suppress_tokens.sort() - if language is not None: - model.generation_config.forced_decoder_ids = None if self.cuda: model.to("cuda") - special_ids = processor.tokenizer.all_special_ids - return_q = queue.Queue(2) + return_q = queue.Queue(100) finished_adding = threading.Event() stopped = threading.Event() loader = WhisperUtteranceLoader( @@ -1132,73 +927,75 @@ def _run(self) -> None: return_q, stopped, finished_adding, - processor, - segmenter=self.segmenter, + model, export_directory=self.export_directory, - device="cuda" if self.cuda else "cpu", ) loader.start() exception = None - current_index = 0 - cache_threshold = 10 - if self.segmenter is None: - cache_threshold = 100 + while True: try: - batch = return_q.get(timeout=1) + vad_result = return_q.get(timeout=1) except queue.Empty: if finished_adding.is_set(): break continue if stopped.is_set(): continue - if isinstance(batch, Exception): - exception = batch + if isinstance(vad_result, Exception): + exception = vad_result stopped.set() continue - utterance_ids, inputs, export_paths = batch - inputs = inputs.to(model.device, model.dtype) - result = model.generate( - **inputs, - condition_on_prev_tokens=False, - temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) if self.segmenter is None else 0.0, - logprob_threshold=-1.0, - compression_ratio_threshold=1.35, - return_timestamps=False, - language=language, - ) - - decoded = [] - for r in result: - r = [t for t in r if t not in special_ids] - tokens = processor.tokenizer.convert_tokens_to_string( - processor.tokenizer.convert_ids_to_tokens(r) - ).strip() - decoded.append(tokens) - if isinstance(utterance_ids, list): - for i, u_id in enumerate(utterance_ids): - text = decoded[i] - if self.tokenizer is not None: - text = self.tokenizer(text)[0] - if export_paths[i] is not None: - export_paths[i].parent.mkdir(parents=True, exist_ok=True) - with mfa_open(export_paths[i], "w") as f: - f.write(text) - self.callback((int(u_id), text)) - else: - text = " ".join(decoded) - if self.tokenizer is not None: - text = self.tokenizer(text)[0] + try: + utterance_id, segments, export_path, speaker_name, begin, end = vad_result + result = model.transcribe( + segments, [utterance_id] * len(segments), batch_size=config.NUM_JOBS + ) + for utterance_id, segments in result.items(): + texts = [] + for seg in segments: + seg["text"] = seg["text"].strip() + if self.tokenizer is not None: + seg["text"] = self.tokenizer(seg["text"])[0] + texts.append(seg["text"]) + text = " ".join(texts) + + if export_path is not None: + export_path: Path + export_path.parent.mkdir(parents=True, exist_ok=True) + if len(segments) == 1: + with mfa_open(export_path.with_suffix(".lab"), "w") as f: + f.write(text) + else: + from praatio import textgrid + + tg = textgrid.Textgrid() + tg.minTimestamp = begin + tg.maxTimestamp = end + tier = textgrid.IntervalTier( + speaker_name, + [ + textgrid.constants.Interval( + round(begin + x["start"], 3), + round(begin + x["end"], 3), + x["text"], + ) + for x in segments + ], + minT=begin, + maxT=end, + ) - if export_paths is not None: - export_paths.parent.mkdir(parents=True, exist_ok=True) - with mfa_open(export_paths, "w") as f: - f.write(text) - self.callback((utterance_ids, text)) - current_index += 1 - if False and current_index > 
cache_threshold: - torch.cuda.empty_cache() - current_index = 0 + tg.addTier(tier) + tg.save( + str(export_path.with_suffix(".TextGrid")), + includeBlankSpaces=True, + format="short_textgrid", + ) + self.callback((utterance_id, text)) + except Exception as e: + exception = e + stopped.set() loader.join() if exception: diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index 3f5e8571..f5f43f34 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -68,10 +68,9 @@ ) from montreal_forced_aligner.models import AcousticModel, LanguageModel from montreal_forced_aligner.textgrid import construct_output_path +from montreal_forced_aligner.transcription.models import FOUND_WHISPERX, load_model from montreal_forced_aligner.transcription.multiprocessing import ( - FOUND_FASTER_WHISPER, FOUND_SPEECHBRAIN, - FOUND_TRANSFORMERS, CarpaLmRescoreArguments, CarpaLmRescoreFunction, CreateHclgArguments, @@ -81,9 +80,6 @@ DecodePhoneArguments, DecodePhoneFunction, EncoderASR, - FasterWhisperArguments, - FasterWhisperCudaArguments, - FasterWhisperFunction, FinalFmllrArguments, FinalFmllrFunction, FmllrRescoreArguments, @@ -101,9 +97,6 @@ WhisperASR, WhisperAsrFunction, WhisperCudaArguments, - WhisperForConditionalGeneration, - WhisperModel, - get_suppressed_tokens, ) from montreal_forced_aligner.utils import ( KaldiProcessWorker, @@ -111,6 +104,7 @@ run_kaldi_function, thirdparty_binary, ) +from montreal_forced_aligner.vad.models import SpeechbrainSegmenterMixin if TYPE_CHECKING: from montreal_forced_aligner.abc import MetaDict @@ -1437,36 +1431,19 @@ def cleanup(self) -> None: super().cleanup() -class WhisperTranscriber(HuggingFaceTranscriber): +class WhisperTranscriber(HuggingFaceTranscriber, SpeechbrainSegmenterMixin): ARCHITECTURES = ["large-v3", "distil-large-v3", "medium", "base", "tiny", "small"] def __init__( self, - architecture: str = "distil-large-v3", - vad: bool = False, + architecture: str = "large-v3", export_directory: Path = None, **kwargs, ): - from montreal_forced_aligner.vad.segmenter import ( - FOUND_SPEECHBRAIN, - SpeechbrainSegmenterMixin, - ) - - if not FOUND_TRANSFORMERS: - logger.error( - "Could not import transformers, please ensure it is installed via `conda install transformers`" + if not FOUND_WHISPERX: + raise Exception( + "Could not import whisperx, please ensure it is installed via `pip install whisperx`" ) - sys.exit(1) - if not FOUND_FASTER_WHISPER: - logger.error( - "Could not import faster-whisper, please ensure it is installed via `pip install faster-whisper`" - ) - sys.exit(1) - if vad and not FOUND_SPEECHBRAIN: - logger.error( - "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" - ) - sys.exit(1) if architecture not in self.ARCHITECTURES: raise ModelError( f"The architecture {architecture} is not in: {', '.join(self.ARCHITECTURES)}" @@ -1475,10 +1452,6 @@ def __init__( self.architecture = architecture self.model = None self.segmenter = None - if vad: - self.segmenter = SpeechbrainSegmenterMixin(cuda=self.cuda) - self.segmenter.apply_energy_vad = True - self.segmenter.double_check = False self.processor = None self.export_directory = export_directory self.transcription_function = WhisperAsrFunction @@ -1497,84 +1470,36 @@ def setup(self) -> None: self.setup_model() super().setup() - def setup_model(self, online=False, faster_whisper=False) -> None: + def setup_model(self, online=False) -> 
None: iso_code = self.language.iso_code if iso_code is None and self.language is not Language.unknown: raise ModelError( f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" ) try: + vad_model = None if self.cuda: config.update_configuration( { "USE_THREADING": True, - # "USE_MP": False, } ) - if faster_whisper: - if self.cuda: - run_opts = {"device": "cuda", "compute_type": "float16"} - self.model = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, - cpu_threads=config.NUM_JOBS, - num_workers=config.NUM_JOBS, - **run_opts, - ) - else: - # Download models if needed - m = WhisperModel( - self.architecture, - download_root=os.path.join( - config.TEMPORARY_DIRECTORY, - "models", - "Whisper", - ), - local_files_only=False, - ) - if online: - self.model = m - else: - import torch - from transformers import WhisperProcessor, logging - from transformers.utils import is_flash_attn_2_available - - if not config.VERBOSE: - logging.set_verbosity_error() - if config.DEBUG: - logging.set_verbosity_debug() - attn_implementation = ( - "flash_attention_2" if is_flash_attn_2_available() else "sdpa" - ) - logger.debug(f"Using {attn_implementation} for attention") - if self.cuda: - self.model = WhisperForConditionalGeneration.from_pretrained( - f"openai/whisper-{self.architecture}", - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - attn_implementation=attn_implementation, - ) - self.model.to("cuda") - else: - m = WhisperForConditionalGeneration.from_pretrained( - f"openai/whisper-{self.architecture}" - ) - if online: - self.model = m - if self.cuda or online: - self.processor = WhisperProcessor.from_pretrained( - f"openai/whisper-{self.architecture}" - ) - suppressed = get_suppressed_tokens(self.processor) - self.model.generation_config.suppress_tokens += suppressed - self.model.generation_config.suppress_tokens = sorted( - set(self.model.generation_config.suppress_tokens) - ) + vad_model = self.vad_model + m = load_model( + self.architecture, + device="cuda" if self.cuda else "cpu", + language=iso_code, + vad_model=vad_model, + vad_options=self.segmentation_options, + download_root=os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "Whisper", + ), + threads=config.NUM_JOBS, + ) + if self.cuda or online: + self.model = m except Exception: raise raise ModelError( @@ -1589,12 +1514,7 @@ def transcribe_arguments(self): getattr(self, "session", ""), self.working_log_directory.joinpath("whisper_asr.log"), self.working_directory, - f"openai/whisper-{self.architecture}", self.model, - self.processor, - self.segmenter, - self.language, - {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, self.export_directory, @@ -1602,45 +1522,12 @@ def transcribe_arguments(self): ] return [ WhisperArguments( - j.id, - getattr(self, "session" if config.USE_THREADING else "db_string", ""), - self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), - self.working_directory, - f"openai/whisper-{self.architecture}", - self.language, - {"beam_size": 5}, - self.tokenizer if self.evaluation_mode else None, - self.cuda, - self.export_directory, - ) - for j in self.jobs - ] - - def faster_whisper_arguments(self): - if self.cuda: - return [ - FasterWhisperCudaArguments( - 1, - getattr(self, "session", ""), - self.working_log_directory.joinpath("whisper_asr.log"), - self.working_directory, - self.model, - self.language, - {"beam_size": 5}, - self.tokenizer 
if self.evaluation_mode else None, - self.cuda, - self.export_directory, - ) - ] - return [ - FasterWhisperArguments( j.id, getattr(self, "session" if config.USE_THREADING else "db_string", ""), self.working_log_directory.joinpath(f"whisper_asr.{j.id}.log"), self.working_directory, self.architecture, self.language, - {"beam_size": 5}, self.tokenizer if self.evaluation_mode else None, self.cuda, self.export_directory, @@ -1648,67 +1535,16 @@ def faster_whisper_arguments(self): for j in self.jobs ] - def transcribe_utterances(self) -> None: - super().transcribe_utterances() - if self.segmenter is not None: - return - workflow = self.current_workflow - iso_code = self.language.iso_code - if iso_code is None: - raise ModelError( - f"The language {self.language.name} not in {', '.join(sorted(ISO_LANGUAGE_MAPPING.keys()))}" - ) - try: - arguments = self.faster_whisper_arguments() - update_mapping = [] - with self.session() as session: - num_utterances = session.query(Utterance).filter(Utterance.duration > 30).count() - if not num_utterances: - return - if self.cuda: - import torch - - del self.model - gc.collect() - torch.cuda.empty_cache() - logger.info("Transcribing longer utterances (>30 seconds)...") - for u_id, transcript in run_kaldi_function( - FasterWhisperFunction, arguments, total_count=num_utterances - ): - update_mapping.append({"id": u_id, "transcription_text": transcript}) - if self.cuda: - import torch - - del self.model - gc.collect() - torch.cuda.empty_cache() - if update_mapping: - bulk_update(session, Utterance, update_mapping) - session.commit() - if self.evaluation_mode: - os.makedirs(self.working_log_directory, exist_ok=True) - self.evaluate_transcriptions() - with self.session() as session: - session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( - {"done": True} - ) - session.commit() - except Exception as e: - with self.session() as session: - session.query(CorpusWorkflow).filter(CorpusWorkflow.id == workflow.id).update( - {"dirty": True} - ) - session.commit() - if isinstance(e, KaldiProcessingError): - log_kaldi_errors(e.error_logs) - e.update_log_file() - raise - def cleanup(self) -> None: if self.model is not None: del self.model - if self.segmenter is not None: - del self.segmenter + if self.vad_model is not None: + del self.vad_model + if self.cuda: + import torch + + gc.collect() + torch.cuda.empty_cache() super().cleanup() @@ -1718,7 +1554,7 @@ class SpeechbrainTranscriber(HuggingFaceTranscriber): def __init__(self, architecture: str = "whisper-medium", **kwargs): if not FOUND_SPEECHBRAIN: logger.error( - "Could not import faster_whisper, please ensure it is installed via `pip install faster-whisper`" + "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" ) sys.exit(1) if architecture not in self.ARCHITECTURES: diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index d5acb8e0..048719e1 100644 --- a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -714,6 +714,13 @@ def run_kaldi_function( logger.debug("Received ctrl+c event") stopped.set() error_dict["main_thread"] = e + import sys + import traceback + + exc_type, exc_value, exc_traceback = sys.exc_info() + logger.debug( + "\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + ) continue finally: @@ -768,6 +775,15 @@ def run_kaldi_function( logger.debug("Received ctrl+c event") stopped.set() error_dict["main_thread"] = e + import sys + import traceback + + exc_type, exc_value, 
exc_traceback = sys.exc_info() + logger.debug( + "\n".join( + traceback.format_exception(exc_type, exc_value, exc_traceback) + ) + ) continue finally: diff --git a/montreal_forced_aligner/vad/models.py b/montreal_forced_aligner/vad/models.py new file mode 100644 index 00000000..25bee8b6 --- /dev/null +++ b/montreal_forced_aligner/vad/models.py @@ -0,0 +1,403 @@ +"""Model classes for Voice Activity Detection""" +from __future__ import annotations + +import logging +import os +import sys +import typing +import warnings +from pathlib import Path + +import numpy as np +from kalpy.data import Segment + +from montreal_forced_aligner import config + +if typing.TYPE_CHECKING: + from montreal_forced_aligner.abc import MetaDict + +try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend") + torch_logger.setLevel(logging.ERROR) + torch_logger = logging.getLogger("speechbrain.utils.train_logger") + torch_logger.setLevel(logging.ERROR) + import torch + import torchaudio + + try: + from speechbrain.pretrained import VAD + except ImportError: # speechbrain 1.0 + from speechbrain.inference.VAD import VAD + + FOUND_SPEECHBRAIN = True +except (ImportError, OSError): + FOUND_SPEECHBRAIN = False + VAD = object + + +logger = logging.getLogger("mfa") + + +class MfaVAD(VAD): + def energy_VAD( + self, + audio_file: typing.Union[str, Path, np.ndarray, torch.Tensor], + boundaries, + activation_th=0.5, + deactivation_th=0.0, + eps=1e-6, + ): + """Applies energy-based VAD within the detected speech segments.The neural + network VAD often creates longer segments and tends to merge segments that + are close with each other. + + The energy VAD post-processes can be useful for having a fine-grained voice + activity detection. + + The energy VAD computes the energy within the small chunks. The energy is + normalized within the segment to have mean 0.5 and +-0.5 of std. + This helps to set the energy threshold. + + Arguments + --------- + audio_file: path + Path of the audio file containing the recording. The file is read + with torchaudio. + boundaries: torch.Tensor + torch.Tensor containing the speech boundaries. It can be derived using the + get_boundaries method. + activation_th: float + A new speech segment is started it the energy is above activation_th. + deactivation_th: float + The segment is considered ended when the energy is <= deactivation_th. + eps: float + Small constant for numerical stability. + + Returns + ------- + new_boundaries + The new boundaries that are post-processed by the energy VAD. 
+ """ + if isinstance(audio_file, (str, Path)): + # Getting the total size of the input file + sample_rate, audio_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + else: + sample_rate = self.sample_rate + + # Computing the chunk length of the energy window + chunk_len = int(self.time_resolution * sample_rate) + new_boundaries = [] + + # Processing speech segments + for i in range(boundaries.shape[0]): + begin_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + seg_len = end_sample - begin_sample + if seg_len < chunk_len: + continue + if not isinstance(audio_file, torch.Tensor): + # Reading the speech segment + segment, _ = torchaudio.load( + audio_file, frame_offset=begin_sample, num_frames=seg_len + ) + else: + segment = audio_file[:, begin_sample : begin_sample + seg_len] + + # Create chunks + segment_chunks = self.create_chunks( + segment, chunk_size=chunk_len, chunk_stride=chunk_len + ) + + # Energy computation within each chunk + energy_chunks = segment_chunks.abs().sum(-1) + eps + energy_chunks = energy_chunks.log() + + # Energy normalization + energy_chunks = ( + (energy_chunks - energy_chunks.mean()) / (2 * energy_chunks.std()) + ) + 0.5 + energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2) + + # Apply threshold based on the energy value + energy_vad = self.apply_threshold( + energy_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ) + + # Get the boundaries + energy_boundaries = self.get_boundaries(energy_vad, output_value="seconds") + + # Get the final boundaries in the original signal + for j in range(energy_boundaries.shape[0]): + start_en = boundaries[i, 0] + energy_boundaries[j, 0] + end_end = boundaries[i, 0] + energy_boundaries[j, 1] + new_boundaries.append([start_en, end_end]) + + # Convert boundaries to tensor + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + + def double_check_speech_segments(self, boundaries, audio_file, speech_th=0.5): + """Takes in input the boundaries of the detected speech segments and + double checks (using the neural VAD) that they actually contain speech. + + Arguments + --------- + boundaries: torch.Tensor + torch.Tensor containing the boundaries of the speech segments. + audio_file: path + The original audio file used to compute vad_out. + speech_th: float + Threshold on the mean posterior probability over which speech is + confirmed. Below that threshold, the segment is re-assigned to a + non-speech region. + + Returns + ------- + new_boundaries + The boundaries of the segments where speech activity is confirmed. 
+ """ + + if isinstance(audio_file, (str, Path)): + # Getting the total size of the input file + sample_rate, audio_len = self._get_audio_info(audio_file) + + if sample_rate != self.sample_rate: + raise ValueError( + "The detected sample rate is different from that set in the hparam file" + ) + else: + sample_rate = self.sample_rate + + # Double check the segments + new_boundaries = [] + for i in range(boundaries.shape[0]): + beg_sample = int(boundaries[i, 0] * sample_rate) + end_sample = int(boundaries[i, 1] * sample_rate) + len_seg = end_sample - beg_sample + + if not isinstance(audio_file, torch.Tensor): + # Read the candidate speech segment + segment, fs = torchaudio.load( + str(audio_file), frame_offset=beg_sample, num_frames=len_seg + ) + else: + segment = audio_file[:, beg_sample : beg_sample + len_seg] + speech_prob = self.get_speech_prob_chunk(segment) + if speech_prob.mean() > speech_th: + # Accept this as a speech segment + new_boundaries.append([boundaries[i, 0], boundaries[i, 1]]) + + # Convert boundaries from list to tensor + new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) + return new_boundaries + + def segment_utterance( + self, + segment: typing.Union[Segment, np.ndarray], + apply_energy_VAD: bool = False, + double_check: bool = False, + close_th: float = 0.333, + len_th: float = 0.333, + activation_th: float = 0.5, + deactivation_th: float = 0.25, + en_activation_th: float = 0.5, + en_deactivation_th: float = 0.4, + speech_th: float = 0.5, + allow_empty: bool = True, + ) -> typing.List[Segment]: + if isinstance(segment, Segment): + y = torch.tensor(segment.wave[np.newaxis, :]) + else: + if len(segment.shape) == 1: + y = torch.tensor(segment[np.newaxis, :]) + elif not torch.is_tensor(segment): + y = torch.tensor(segment) + else: + y = segment + prob_chunks = self.get_speech_prob_chunk(y).float() + prob_th = self.apply_threshold( + prob_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ).float() + + # Compute the boundaries of the speech segments + boundaries = self.get_boundaries(prob_th, output_value="seconds").cpu() + if isinstance(segment, Segment) and segment.begin is not None: + boundaries += segment.begin + # Apply energy-based VAD on the detected speech segments + if apply_energy_VAD: + vad_boundaries = self.energy_VAD( + y, + boundaries, + activation_th=en_activation_th, + deactivation_th=en_deactivation_th, + ) + if vad_boundaries.size(0) != 0 or allow_empty: + boundaries = vad_boundaries + + # Merge short segments + boundaries = self.merge_close_segments(boundaries, close_th=close_th) + + # Remove short segments + filtered_boundaries = self.remove_short_segments(boundaries, len_th=len_th) + if filtered_boundaries.size(0) != 0 or allow_empty: + boundaries = filtered_boundaries + + # Double check speech segments + if double_check: + checked_boundaries = self.double_check_speech_segments( + boundaries, y, speech_th=speech_th + ) + if checked_boundaries.size(0) != 0 or allow_empty: + boundaries = checked_boundaries + boundaries[:, 0] -= round(close_th / 2, 3) + boundaries[:, 1] += round(close_th / 2, 3) + segments = [] + for i in range(boundaries.numpy().shape[0]): + begin, end = boundaries[i] + if i == 0: + begin = max(begin, 0) + if i == boundaries.numpy().shape[0] - 1: + end = min( + end, + segment.end + if isinstance(segment, Segment) + else segment.shape[0] / self.sample_rate, + ) + seg = Segment( + segment.file_path if isinstance(segment, Segment) else "", + float(begin), + float(end), + segment.channel if 
isinstance(segment, Segment) else 0, + ) + segments.append(seg) + return segments + + def segment_for_whisper( + self, + segment: typing.Union[torch.Tensor, np.ndarray], + apply_energy_VAD: bool = False, + close_th: float = 0.333, + len_th: float = 0.333, + activation_th: float = 0.5, + deactivation_th: float = 0.25, + en_activation_th: float = 0.5, + en_deactivation_th: float = 0.4, + **kwargs, + ) -> typing.List[typing.Dict[str, float]]: + if len(segment.shape) == 1: + y = torch.tensor(segment[np.newaxis, :]) + elif not torch.is_tensor(segment): + y = torch.tensor(segment) + else: + y = segment + prob_chunks = self.get_speech_prob_chunk(y).float() + prob_th = self.apply_threshold( + prob_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + ).float() + + # Compute the boundaries of the speech segments + boundaries = self.get_boundaries(prob_th, output_value="seconds").cpu() + del prob_chunks + del prob_th + + # Apply energy-based VAD on the detected speech segments + if apply_energy_VAD: + vad_boundaries = self.energy_VAD( + y, + boundaries, + activation_th=en_activation_th, + deactivation_th=en_deactivation_th, + ) + boundaries = vad_boundaries + + # Merge short segments + boundaries = self.merge_close_segments(boundaries, close_th=close_th) + + # Remove short segments + filtered_boundaries = self.remove_short_segments(boundaries, len_th=len_th) + if filtered_boundaries.size(0) != 0: + boundaries = filtered_boundaries + boundaries[:, 0] -= round(close_th / 2, 3) + boundaries[:, 1] += round(close_th / 2, 3) + segments = [] + for i in range(boundaries.numpy().shape[0]): + begin, end = boundaries[i] + if i == 0: + begin = max(begin, 0) + if i == boundaries.numpy().shape[0] - 1: + end = min(end, segment.shape[0] / self.sample_rate) + f1 = int(float(begin) * self.sample_rate) + f2 = int(float(end) * self.sample_rate) + segments.append({"start": float(begin), "end": float(end), "inputs": y[0, f1:f2]}) + return segments + + +class SpeechbrainSegmenterMixin: + def __init__( + self, + apply_energy_vad: bool = False, + double_check: bool = False, + close_th: float = 0.333, + len_th: float = 0.333, + activation_th: float = 0.5, + deactivation_th: float = 0.25, + en_activation_th: float = 0.5, + en_deactivation_th: float = 0.4, + speech_th: float = 0.5, + cuda: bool = False, + **kwargs, + ): + if not FOUND_SPEECHBRAIN: + logger.error( + "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" + ) + sys.exit(1) + super().__init__(**kwargs) + self.apply_energy_vad = apply_energy_vad + self.double_check = double_check + self.close_th = close_th + self.len_th = len_th + self.activation_th = activation_th + self.deactivation_th = deactivation_th + self.en_activation_th = en_activation_th + self.en_deactivation_th = en_deactivation_th + self.speech_th = speech_th + self.cuda = cuda + self.speechbrain = True + self.vad_model = None + model_dir = os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD") + os.makedirs(model_dir, exist_ok=True) + run_opts = None + if self.cuda: + run_opts = {"device": "cuda"} + self.vad_model = MfaVAD.from_hparams( + source="speechbrain/vad-crdnn-libriparty", savedir=model_dir, run_opts=run_opts + ) + + @property + def segmentation_options(self) -> MetaDict: + """Options for segmentation""" + return { + "apply_energy_VAD": self.apply_energy_vad, + "double_check": self.double_check, + "activation_th": self.activation_th, + "deactivation_th": self.deactivation_th, + "en_activation_th": self.en_activation_th, + 
"en_deactivation_th": self.en_deactivation_th, + "speech_th": self.speech_th, + "close_th": self.close_th, + "len_th": self.len_th, + } diff --git a/montreal_forced_aligner/vad/multiprocessing.py b/montreal_forced_aligner/vad/multiprocessing.py index cde7ac90..1037c7da 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -1,13 +1,11 @@ """Multiprocessing functionality for VAD""" from __future__ import annotations -import logging import typing from pathlib import Path from typing import TYPE_CHECKING, List, Union import numpy -import numpy as np import pynini import pywrapfst from _kalpy.decoder import LatticeFasterDecoder, LatticeFasterDecoderConfig @@ -30,132 +28,7 @@ from montreal_forced_aligner.db import File, Job, Speaker, Utterance from montreal_forced_aligner.exceptions import SegmenterError from montreal_forced_aligner.models import AcousticModel, G2PModel - -try: - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - torch_logger = logging.getLogger("speechbrain.utils.torch_audio_backend") - torch_logger.setLevel(logging.ERROR) - torch_logger = logging.getLogger("speechbrain.utils.train_logger") - torch_logger.setLevel(logging.ERROR) - import torch - import torchaudio - - try: - from speechbrain.pretrained import VAD - except ImportError: # speechbrain 1.0 - from speechbrain.inference.VAD import VAD - - class MfaVAD(VAD): - def energy_VAD( - self, - audio_file: typing.Union[str, Path, np.ndarray], - boundaries, - activation_th=0.5, - deactivation_th=0.0, - eps=1e-6, - ): - """Applies energy-based VAD within the detected speech segments.The neural - network VAD often creates longer segments and tends to merge segments that - are close with each other. - - The energy VAD post-processes can be useful for having a fine-grained voice - activity detection. - - The energy VAD computes the energy within the small chunks. The energy is - normalized within the segment to have mean 0.5 and +-0.5 of std. - This helps to set the energy threshold. - - Arguments - --------- - audio_file: path - Path of the audio file containing the recording. The file is read - with torchaudio. - boundaries: torch.Tensor - torch.Tensor containing the speech boundaries. It can be derived using the - get_boundaries method. - activation_th: float - A new speech segment is started it the energy is above activation_th. - deactivation_th: float - The segment is considered ended when the energy is <= deactivation_th. - eps: float - Small constant for numerical stability. - - Returns - ------- - new_boundaries - The new boundaries that are post-processed by the energy VAD. 
- """ - if not isinstance(audio_file, np.ndarray): - # Getting the total size of the input file - sample_rate, audio_len = self._get_audio_info(audio_file) - - if sample_rate != self.sample_rate: - raise ValueError( - "The detected sample rate is different from that set in the hparam file" - ) - else: - sample_rate = self.sample_rate - - # Computing the chunk length of the energy window - chunk_len = int(self.time_resolution * sample_rate) - new_boundaries = [] - - # Processing speech segments - for i in range(boundaries.shape[0]): - begin_sample = int(boundaries[i, 0] * sample_rate) - end_sample = int(boundaries[i, 1] * sample_rate) - seg_len = end_sample - begin_sample - - if not isinstance(audio_file, np.ndarray): - # Reading the speech segment - segment, _ = torchaudio.load( - audio_file, frame_offset=begin_sample, num_frames=seg_len - ) - else: - segment = audio_file[begin_sample : begin_sample + seg_len] - - # Create chunks - segment_chunks = self.create_chunks( - segment, chunk_size=chunk_len, chunk_stride=chunk_len - ) - - # Energy computation within each chunk - energy_chunks = segment_chunks.abs().sum(-1) + eps - energy_chunks = energy_chunks.log() - - # Energy normalization - energy_chunks = ( - (energy_chunks - energy_chunks.mean()) / (2 * energy_chunks.std()) - ) + 0.5 - energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2) - - # Apply threshold based on the energy value - energy_vad = self.apply_threshold( - energy_chunks, - activation_th=activation_th, - deactivation_th=deactivation_th, - ) - - # Get the boundaries - energy_boundaries = self.get_boundaries(energy_vad, output_value="seconds") - - # Get the final boundaries in the original signal - for j in range(energy_boundaries.shape[0]): - start_en = boundaries[i, 0] + energy_boundaries[j, 0] - end_end = boundaries[i, 0] + energy_boundaries[j, 1] - new_boundaries.append([start_en, end_end]) - - # Convert boundaries to tensor - new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) - return new_boundaries - - FOUND_SPEECHBRAIN = True -except (ImportError, OSError): - FOUND_SPEECHBRAIN = False - MfaVAD = None +from montreal_forced_aligner.vad.models import MfaVAD if TYPE_CHECKING: SpeakerCharacterType = Union[str, int] @@ -174,7 +47,6 @@ def energy_VAD( "merge_segments", "segment_utterance_transcript", "segment_utterance_vad", - "segment_utterance_vad_speech_brain", ] @@ -191,7 +63,7 @@ class SegmentTranscriptArguments(MfaArguments): """Arguments for :class:`~montreal_forced_aligner.segmenter.SegmentTranscriptFunction`""" acoustic_model: AcousticModel - vad_model: typing.Optional[VAD] + vad_model: typing.Optional[MfaVAD] lexicon_compilers: typing.Dict[int, LexiconCompiler] mfcc_options: MetaDict vad_options: MetaDict @@ -231,8 +103,8 @@ def segment_utterance( if vad_model is None: segments = segment_utterance_vad(segment, mfcc_options, vad_options, segmentation_options) else: - segments = segment_utterance_vad_speech_brain( - segment, vad_model, segmentation_options, allow_empty=allow_empty + segments = vad_model.segment_utterance( + segment, **segmentation_options, allow_empty=allow_empty ) if not segments: return [segment] @@ -577,68 +449,6 @@ def segment_utterance_vad( return new_segments -def segment_utterance_vad_speech_brain( - segment: Segment, - vad_model: MfaVAD, - segmentation_options: MetaDict, - allow_empty: bool = True, -) -> typing.List[Segment]: - y = segment.wave - prob_chunks = vad_model.get_speech_prob_chunk(torch.tensor(y[np.newaxis, :])).float() - prob_th = vad_model.apply_threshold( - 
prob_chunks, - activation_th=segmentation_options["activation_th"], - deactivation_th=segmentation_options["deactivation_th"], - ).float() - - # Compute the boundaries of the speech segments - boundaries = vad_model.get_boundaries(prob_th, output_value="seconds").cpu() - if segment.begin is not None: - boundaries += segment.begin - # Apply energy-based VAD on the detected speech segments - if segmentation_options["apply_energy_VAD"]: - vad_boundaries = vad_model.energy_VAD( - segment.file_path, - boundaries, - activation_th=segmentation_options["en_activation_th"], - deactivation_th=segmentation_options["en_deactivation_th"], - ) - if vad_boundaries.size(0) != 0 or allow_empty: - boundaries = vad_boundaries - - # Merge short segments - boundaries = vad_model.merge_close_segments( - boundaries, close_th=segmentation_options["close_th"] - ) - - # Remove short segments - filtered_boundaries = vad_model.remove_short_segments( - boundaries, len_th=segmentation_options["len_th"] - ) - if filtered_boundaries.size(0) != 0 or allow_empty: - boundaries = filtered_boundaries - - # Double check speech segments - if segmentation_options["double_check"]: - checked_boundaries = vad_model.double_check_speech_segments( - boundaries, segment.file_path, speech_th=segmentation_options["speech_th"] - ) - if checked_boundaries.size(0) != 0 or allow_empty: - boundaries = checked_boundaries - boundaries[:, 0] -= round(segmentation_options["close_th"] / 2, 3) - boundaries[:, 1] += round(segmentation_options["close_th"] / 2, 3) - segments = [] - for i in range(boundaries.numpy().shape[0]): - begin, end = boundaries[i] - if i == 0: - begin = max(begin, 0) - if i == boundaries.numpy().shape[0] - 1: - end = min(end, segment.end) - seg = Segment(segment.file_path, float(begin), float(end), segment.channel) - segments.append(seg) - return segments - - class SegmentVadFunction(KaldiFunction): """ Multiprocessing function to generate segments from VAD output. 
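With segment_utterance_vad_speech_brain removed here, callers now go through MfaVAD.segment_utterance directly, using the same option names that SpeechbrainSegmenterMixin.segmentation_options emits. A minimal sketch of that call pattern, assuming a placeholder wav path and reusing the default thresholds from this patch::

    # Sketch of the replacement call pattern; the wav path and timestamps are illustrative.
    import os

    from kalpy.data import Segment

    from montreal_forced_aligner import config
    from montreal_forced_aligner.vad.models import MfaVAD

    vad_model = MfaVAD.from_hparams(
        source="speechbrain/vad-crdnn-libriparty",
        savedir=os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD"),
    )
    segment = Segment("/path/to/recording.wav", 0.0, 12.5, 0)  # hypothetical recording
    segmentation_options = {
        # keys intentionally mirror SpeechbrainSegmenterMixin.segmentation_options
        "apply_energy_VAD": True,
        "double_check": False,
        "activation_th": 0.5,
        "deactivation_th": 0.25,
        "en_activation_th": 0.5,
        "en_deactivation_th": 0.4,
        "speech_th": 0.5,
        "close_th": 0.333,
        "len_th": 0.333,
    }
    subsegments = vad_model.segment_utterance(segment, **segmentation_options, allow_empty=False)
    for s in subsegments:
        print(round(s.begin, 3), round(s.end, 3))

Keeping those key names aligned with the mixin property is what lets segment_utterance in this module expand **segmentation_options straight into the method call, which is the drop-in replacement the hunk above performs.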
diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index be960e9b..27cb0f06 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -8,7 +8,6 @@ import collections import logging import os -import sys import typing from pathlib import Path from typing import Dict, List, Optional @@ -37,9 +36,8 @@ from montreal_forced_aligner.tokenization.spacy import generate_language_tokenizer from montreal_forced_aligner.transcription.transcriber import TranscriberMixin from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function +from montreal_forced_aligner.vad.models import SpeechbrainSegmenterMixin from montreal_forced_aligner.vad.multiprocessing import ( - FOUND_SPEECHBRAIN, - MfaVAD, SegmentTranscriptArguments, SegmentTranscriptFunction, SegmentVadArguments, @@ -50,81 +48,11 @@ SegmentationType = List[Dict[str, float]] -__all__ = ["VadSegmenter", "SpeechbrainSegmenterMixin", "TranscriptionSegmenter"] +__all__ = ["VadSegmenter", "TranscriptionSegmenter"] logger = logging.getLogger("mfa") -class SpeechbrainSegmenterMixin: - def __init__( - self, - segment_padding: float = 0.1, - large_chunk_size: float = 30, - small_chunk_size: float = 0.05, - overlap_small_chunk: bool = False, - apply_energy_vad: bool = False, - double_check: bool = True, - close_th: float = 0.333, - len_th: float = 0.333, - activation_th: float = 0.5, - deactivation_th: float = 0.25, - en_activation_th: float = 0.5, - en_deactivation_th: float = 0.4, - speech_th: float = 0.5, - cuda: bool = False, - **kwargs, - ): - if not FOUND_SPEECHBRAIN: - logger.error( - "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" - ) - sys.exit(1) - super().__init__(**kwargs) - self.large_chunk_size = large_chunk_size - self.small_chunk_size = small_chunk_size - self.overlap_small_chunk = overlap_small_chunk - self.apply_energy_vad = apply_energy_vad - self.double_check = double_check - self.close_th = close_th - self.len_th = len_th - self.activation_th = activation_th - self.deactivation_th = deactivation_th - self.en_activation_th = en_activation_th - self.en_deactivation_th = en_deactivation_th - self.speech_th = speech_th - self.cuda = cuda - self.speechbrain = True - self.segment_padding = segment_padding - self.vad_model = None - model_dir = os.path.join(config.TEMPORARY_DIRECTORY, "models", "VAD") - os.makedirs(model_dir, exist_ok=True) - run_opts = None - if self.cuda: - run_opts = {"device": "cuda"} - self.vad_model = MfaVAD.from_hparams( - source="speechbrain/vad-crdnn-libriparty", savedir=model_dir, run_opts=run_opts - ) - - @property - def segmentation_options(self) -> MetaDict: - """Options for segmentation""" - return { - "large_chunk_size": self.large_chunk_size, - "frame_shift": getattr(self, "export_frame_shift", 0.01), - "small_chunk_size": self.small_chunk_size, - "overlap_small_chunk": self.overlap_small_chunk, - "apply_energy_VAD": self.apply_energy_vad, - "double_check": self.double_check, - "activation_th": self.activation_th, - "deactivation_th": self.deactivation_th, - "en_activation_th": self.en_activation_th, - "en_deactivation_th": self.en_deactivation_th, - "speech_th": self.speech_th, - "close_th": self.close_th, - "len_th": self.len_th, - } - - class VadSegmenter( VadConfigMixin, AcousticCorpusMixin, @@ -242,13 +170,22 @@ def segment_vad_arguments(self) -> List[SegmentVadArguments]: list[SegmentVadArguments] Arguments for processing """ + options = 
self.segmentation_options + options.update( + { + "large_chunk_size": self.large_chunk_size, + "frame_shift": getattr(self, "export_frame_shift", 0.01), + "small_chunk_size": self.small_chunk_size, + "overlap_small_chunk": self.overlap_small_chunk, + } + ) return [ SegmentVadArguments( j.id, getattr(self, "session" if config.USE_THREADING else "db_string", ""), self.working_log_directory.joinpath(f"segment_vad.{j.id}.log"), j.construct_path(self.split_directory, "vad", "scp"), - self.segmentation_options, + options, ) for j in self.jobs ] @@ -268,7 +205,6 @@ def segment_vad_speechbrain(self) -> None: old_utts = set() new_utts = [] kwargs = self.segmentation_options - kwargs.pop("frame_shift") with tqdm( total=self.num_utterances, disable=config.QUIET ) as pbar, self.session() as session: @@ -288,8 +224,8 @@ def segment_vad_speechbrain(self) -> None: for i in range(boundaries.shape[0]): old_utts.add(u.id) begin, end = boundaries[i, :] - begin -= round(kwargs["close_th"] / 2, 3) - end += round(kwargs["close_th"] / 2, 3) + begin -= round(self.close_th / 2, 3) + end += round(self.close_th / 2, 3) if i == 0: begin = max(0.0, begin) if i == boundaries.shape[0] - 1: diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 88cc7b31..88aecb06 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -4,11 +4,8 @@ import pytest from montreal_forced_aligner.command_line.mfa import mfa_cli -from montreal_forced_aligner.transcription.multiprocessing import ( - FOUND_FASTER_WHISPER, - FOUND_SPEECHBRAIN, - FOUND_TRANSFORMERS, -) +from montreal_forced_aligner.transcription.models import FOUND_WHISPERX +from montreal_forced_aligner.transcription.multiprocessing import FOUND_SPEECHBRAIN def test_transcribe( @@ -97,10 +94,8 @@ def test_transcribe_whisper( temp_dir, db_setup, ): - if not FOUND_TRANSFORMERS: - pytest.skip("transformers not installed") - if not FOUND_FASTER_WHISPER: - pytest.skip("faster-whisper not installed") + if not FOUND_WHISPERX: + pytest.skip("whisperx not installed") output_path = generated_dir.joinpath("transcribe_test_whisper") command = [ "transcribe_whisper", From 5dfbda29e85d1f49fc038db530ad7d66bb59578b Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Thu, 19 Sep 2024 18:13:57 -0700 Subject: [PATCH 05/16] 3P transcription generally working --- bin/mfa_update | 52 ++- docs/source/changelog/changelog_3.0.rst | 7 + environment.yml | 8 +- montreal_forced_aligner/abc.py | 7 +- montreal_forced_aligner/corpus/classes.py | 3 + montreal_forced_aligner/data.py | 244 +++++------ montreal_forced_aligner/db.py | 30 ++ montreal_forced_aligner/dictionary/mixins.py | 7 +- .../dictionary/multispeaker.py | 10 +- montreal_forced_aligner/g2p/generator.py | 22 +- .../g2p/phonetisaurus_trainer.py | 12 +- montreal_forced_aligner/g2p/trainer.py | 19 +- montreal_forced_aligner/models.py | 9 +- .../tokenization/simple.py | 3 +- .../transcription/models.py | 43 +- .../transcription/multiprocessing.py | 3 - .../transcription/transcriber.py | 35 +- montreal_forced_aligner/utils.py | 2 + montreal_forced_aligner/vad/models.py | 381 ++++++++++++------ .../vad/multiprocessing.py | 95 +---- montreal_forced_aligner/vad/segmenter.py | 18 +- tests/test_commandline_align.py | 10 +- tests/test_commandline_transcribe.py | 1 - tests/test_corpus.py | 21 +- 24 files changed, 618 insertions(+), 424 deletions(-) diff --git a/bin/mfa_update b/bin/mfa_update index b6a03f2d..f7d2d4a2 100644 --- a/bin/mfa_update +++ b/bin/mfa_update @@ -1,18 +1,50 
@@ #!/usr/bin/env python +import argparse import os import shutil import subprocess +import sys from importlib.util import find_spec -anchor_found = find_spec("anchor") is not None +if __name__ == "__main__": + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--install_3p", + action="store_true", + help="Install/update third party dependencies (Speechbrain and WhisperX)", + ) + args = parser.parse_args() + anchor_found = find_spec("anchor") is not None + speechbrain_found = find_spec("speechbrain") is not None + whisperx_found = find_spec("whisperx") is not None -conda_path = shutil.which("conda") -mamba_path = shutil.which("mamba") -if mamba_path is None: - print("No mamba found, installing first...") - subprocess.call([conda_path, "install", "-c", "conda-forge", "-y", "mamba"], env=os.environ) -package_list = ["montreal-forced-aligner", "kalpy", "kaldi=*=cpu*"] -if anchor_found: - package_list.append("anchor-annotator") -subprocess.call([mamba_path, "update", "-c", "conda-forge", "-y"] + package_list, env=os.environ) + conda_path = shutil.which("conda") + if conda_path is None: + print("Please install conda before running this command.") + sys.exit(1) + mamba_path = shutil.which("mamba") + if mamba_path is None: + print("No mamba found, installing first...") + subprocess.call( + [conda_path, "install", "-c", "conda-forge", "-y", "mamba"], env=os.environ + ) + package_list = ["montreal-forced-aligner", "kalpy", "kaldi=*=cpu*"] + if anchor_found: + package_list.append("anchor-annotator") + subprocess.call( + [mamba_path, "update", "-c", "conda-forge", "-y"] + package_list, env=os.environ + ) + if args.install_3p: + channels = ["conda-forge", "pytorch", "nvidia", "anaconda"] + package_list = ["pytorch", "torchaudio"] + if not whisperx_found: + package_list.extend(["cudnn=8", "transformers"]) + command = [mamba_path, "install", "-y"] + for c in channels: + command.extend(["-c", c]) + command += package_list + subprocess.call(command, env=os.environ) + command = ["pip", "install", "-U"] + package_list = ["whisperx", "speechbrain", "pygtrie"] + subprocess.call(command, env=os.environ) diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst index 5a789434..74903e4a 100644 --- a/docs/source/changelog/changelog_3.0.rst +++ b/docs/source/changelog/changelog_3.0.rst @@ -5,6 +5,13 @@ 3.0 Changelog ************* +3.2.0 +----- + +- Add support for transcription via whisperx and speechbrain models +- Update text normalization to normalize to decomposed forms +- Compatibility with Kalpy 0.6.7 + 3.1.4 ----- diff --git a/environment.yml b/environment.yml index 826d4ef4..c442fd11 100644 --- a/environment.yml +++ b/environment.yml @@ -29,8 +29,6 @@ dependencies: - postgresql - psycopg2 - click - - pytorch - - torchaudio - setuptools_scm - pytest - pytest-mypy @@ -52,12 +50,11 @@ dependencies: - sudachipy - sudachidict-core - spacy-pkuseg + - pytorch + - torchaudio # WhisperX dependencies - cudnn =8 - transformers - - tokenizers - - accelerate - - tiktoken - pip: - build - twine @@ -69,7 +66,6 @@ dependencies: - dragonmapper # Speechbrain dependencies - speechbrain - - kenlm - pygtrie # WhisperX dependencies - whisperx diff --git a/montreal_forced_aligner/abc.py b/montreal_forced_aligner/abc.py index 80666584..f2b34a8d 100644 --- a/montreal_forced_aligner/abc.py +++ b/montreal_forced_aligner/abc.py @@ -618,6 +618,8 @@ def parse_args( dict[str, Any] Dictionary of specified configuration parameters """ + 
from montreal_forced_aligner.data import Language + param_types = cls.get_configuration_parameters() params = {} unknown_dict = {} @@ -640,7 +642,10 @@ def parse_args( ): continue if args is not None and name in args and args[name] is not None: - params[name] = param_type(args[name]) + if param_type == Language: + params[name] = param_type[args[name]] + else: + params[name] = param_type(args[name]) elif name in unknown_dict: params[name] = param_type(unknown_dict[name]) if param_type == bool and not isinstance(unknown_dict[name], bool): diff --git a/montreal_forced_aligner/corpus/classes.py b/montreal_forced_aligner/corpus/classes.py index 62f32427..31a34a90 100644 --- a/montreal_forced_aligner/corpus/classes.py +++ b/montreal_forced_aligner/corpus/classes.py @@ -5,6 +5,7 @@ import sys import traceback import typing +import unicodedata from typing import TYPE_CHECKING, Optional, Union from praatio import textgrid @@ -135,6 +136,7 @@ def load_text( if self.text_type == TextFileType.LAB: try: text = load_text(self.text_path) + text = unicodedata.normalize("NFKC", text) except UnicodeDecodeError: raise TextParseError(self.text_path) if self.wav_info is None: @@ -190,6 +192,7 @@ def load_text( text = text.strip() if not text: continue + text = unicodedata.normalize("NFKC", text) begin, end = round(begin, 4), round(end, 4) if begin >= duration: continue diff --git a/montreal_forced_aligner/data.py b/montreal_forced_aligner/data.py index 2b9f2eca..f3ff12e2 100644 --- a/montreal_forced_aligner/data.py +++ b/montreal_forced_aligner/data.py @@ -475,152 +475,128 @@ class ClusterType(enum.Enum): meanshift = "meanshift" -ISO_LANGUAGE_MAPPING = { - "afrikaans": "af", - "amharic": "am", - "arabic": "ar", - "assamese": "as", - "azerbaijani": "az", - "bashkir": "ba", - "belarusian": "be", - "bulgarian": "bg", - "bengali": "bn", - "tibetan": "bo", - "breton": "br", - "bosnian": "bs", - "catalan": "ca", - "czech": "cs", - "welsh": "cy", - "danish": "da", - "german": "de", - "greek": "el", - "english": "en", - "spanish": "es", - "estonian": "et", - "basque": "eu", - "farsi": "fa", - "finnish": "fi", - "faroese": "fo", - "french": "fr", - "galician": "gl", - "gujarati": "gu", - "hausa": "ha", - "hebrew": "he", - "hindi": "hi", - "croatian": "hr", - "haitian": "ht", - "hungarian": "hu", - "armenian": "hy", - "indonesian": "id", - "icelandic": "is", - "italian": "it", - "japanese": "ja", - "georgian": "ka", - "kazakh": "kk", - "central khmer": "km", - "kannada": "kn", - "korean": "ko", - "latin": "la", - "luxembourgish": "lb", - "lingala": "ln", - "lao": "lo", - "lithuanian": "lt", - "latvian": "lv", - "malagasy": "mg", - "maori": "mi", - "macedonian": "mk", - "malayalam": "ml", - "mongolian": "mn", - "marathi": "mr", - "malay": "ms", - "maltese": "mt", - "burmese": "my", - "nepali": "ne", - "dutch": "nl", - "flemish": "nl", - "norwegian nynorsk": "nn", - "norwegian": "no", - "occitan": "oc", - "punjabi": "pa", - "polish": "pl", - "pashto": "ps", - "portuguese": "pt", - "romanian": "ro", - "moldavian": "ro", - "russian": "ru", - "sanskrit": "sa", - "sindhi": "sd", - "sinhala": "si", - "slovak": "sk", - "slovenian": "sl", - "shona": "sn", - "somali": "so", - "albanian": "sq", - "serbian": "sr", - "sundanese": "su", - "swedish": "sv", - "swahili": "sw", - "tamil": "ta", - "telegu": "te", - "tajik": "tg", - "thai": "th", - "turkmen": "tk", - "tagalog": "tl", - "turkish": "tr", - "tatar": "tt", - "ukrainian": "uk", - "urdu": "ur", - "uzbek": "uz", - "vietnamese": "vi", - "yiddish": "yi", - "yoruba": "yo", - "yue": 
"yue", - "chinese": "zh", - "kinyarwanda": "rw", - "mandarin": "zh-CN", -} +ISO_LANGUAGE_MAPPING = {} class Language(enum.Enum): """Enum for supported languages""" - unknown = "unknown" - catalan = "catalan" - chinese = "chinese" - croatian = "croatian" - danish = "danish" - dutch = "dutch" - english = "english" - finnish = "finnish" - french = "french" - german = "german" - greek = "greek" - italian = "italian" - japanese = "japanese" - korean = "korean" - lithuanian = "lithuanian" - macedonian = "macedonian" - multilingual = "multilingual" - norwegian = "norwegian" - polish = "polish" - portuguese = "portuguese" - romanian = "romanian" - russian = "russian" - slovenian = "slovenian" - spanish = "spanish" - swedish = "swedish" - thai = "thai" - ukrainian = "ukrainian" + unknown = None + afrikaans = "af" + amharic = "am" + arabic = "ar" + assamese = "as" + azerbaijani = "az" + bashkir = "ba" + belarusian = "be" + bulgarian = "bg" + bengali = "bn" + tibetan = "bo" + breton = "br" + bosnian = "bs" + catalan = "ca" + czech = "cs" + welsh = "cy" + danish = "da" + german = "de" + greek = "el" + english = "en" + spanish = "es" + estonian = "et" + basque = "eu" + farsi = "fa" + finnish = "fi" + faroese = "fo" + french = "fr" + galician = "gl" + gujarati = "gu" + hausa = "ha" + hebrew = "he" + hindi = "hi" + croatian = "hr" + haitian = "ht" + hungarian = "hu" + armenian = "hy" + indonesian = "id" + icelandic = "is" + italian = "it" + japanese = "ja" + georgian = "ka" + kazakh = "kk" + central_khmer = "km" + kannada = "kn" + korean = "ko" + latin = "la" + luxembourgish = "lb" + lingala = "ln" + lao = "lo" + lithuanian = "lt" + latvian = "lv" + malagasy = "mg" + maori = "mi" + macedonian = "mk" + malayalam = "ml" + mongolian = "mn" + marathi = "mr" + malay = "ms" + maltese = "mt" + burmese = "my" + nepali = "ne" + dutch = "nl" + flemish = "nl" + norwegian_nynorsk = "nn" + norwegian = "no" + occitan = "oc" + punjabi = "pa" + polish = "pl" + pashto = "ps" + portuguese = "pt" + romanian = "ro" + moldavian = "ro" + russian = "ru" + sanskrit = "sa" + sindhi = "sd" + sinhala = "si" + slovak = "sk" + slovenian = "sl" + shona = "sn" + somali = "so" + albanian = "sq" + serbian = "sr" + sundanese = "su" + swedish = "sv" + swahili = "sw" + tamil = "ta" + telegu = "te" + tajik = "tg" + thai = "th" + turkmen = "tk" + tagalog = "tl" + turkish = "tr" + tatar = "tt" + ukrainian = "uk" + urdu = "ur" + uzbek = "uz" + vietnamese = "vi" + yiddish = "yi" + yoruba = "yo" + yue = "yue" + chinese = "zh" + kinyarwanda = "rw" + mandarin = "zh-CN" + multilingual = None def __str__(self) -> str: """Name of phone set""" return self.name + @property + def display_name(self): + return self.name.replace("_", " ").title() + @property def iso_code(self) -> typing.Optional[str]: - if self.value in ISO_LANGUAGE_MAPPING: - return ISO_LANGUAGE_MAPPING[self.value] - return None + return self.value class ManifoldAlgorithm(enum.Enum): diff --git a/montreal_forced_aligner/db.py b/montreal_forced_aligner/db.py index 801d0ad9..dd649617 100644 --- a/montreal_forced_aligner/db.py +++ b/montreal_forced_aligner/db.py @@ -1202,6 +1202,8 @@ def save( ) ) else: + if utterance.end < utterance.begin: + utterance.begin, utterance.end = utterance.end, utterance.begin if tiers[utterance.speaker.name].entries: if tiers[utterance.speaker.name].entries[-1].end > utterance.begin: utterance.begin = tiers[utterance.speaker.name].entries[-1].end @@ -1318,6 +1320,34 @@ def normalized_waveform( x = np.linspace(start=begin, stop=end, num=num_steps) return x, y + 
def load_audio( + self, begin: float = 0, end: typing.Optional[float] = None + ) -> typing.Tuple[np.array, np.array]: + """ + Load a normalized waveform for acoustic processing/visualization + + Parameters + ---------- + begin: float, optional + Starting time point to return, defaults to 0 + end: float, optional + Ending time point to return, defaults to the end of the file + + Returns + ------- + numpy.array + Time points + numpy.array + Sample values + """ + if end is None or end > self.duration: + end = self.duration + + y, _ = librosa.load( + self.sound_file_path, sr=16000, mono=False, offset=begin, duration=end - begin + ) + return y + class TextFile(MfaSqlBase): """ diff --git a/montreal_forced_aligner/dictionary/mixins.py b/montreal_forced_aligner/dictionary/mixins.py index 46743e4e..a986ec03 100644 --- a/montreal_forced_aligner/dictionary/mixins.py +++ b/montreal_forced_aligner/dictionary/mixins.py @@ -355,6 +355,7 @@ def specials_set(self) -> Set[str]: self.oov_word, self.bracketed_word, self.laughter_word, + self.cutoff_word, "", "", } @@ -729,7 +730,7 @@ def _write_topo(self) -> None: for i in range(num_states): if i == 0: # Initial non_silence state if min_states == max_states: - transition_string = f" {i} 0.5 {i+1} 0.5" + transition_string = f" {i} 0.5 {i + 1} 0.5" else: transition_probability = 1 / max_states transition_string = " ".join( @@ -741,11 +742,11 @@ def _write_topo(self) -> None: ) elif i == num_states - 1: non_silence_lines.append( - f" {i} {i} {i+1} 1.0 " + f" {i} {i} {i + 1} 1.0 " ) else: non_silence_lines.append( - f" {i} {i} {i} 0.5 {i+1} 0.5 " + f" {i} {i} {i} 0.5 {i + 1} 0.5 " ) non_silence_lines.append(f" {num_states} ") non_silence_lines.append("") diff --git a/montreal_forced_aligner/dictionary/multispeaker.py b/montreal_forced_aligner/dictionary/multispeaker.py index e6a7e37b..0eac9db8 100644 --- a/montreal_forced_aligner/dictionary/multispeaker.py +++ b/montreal_forced_aligner/dictionary/multispeaker.py @@ -8,6 +8,7 @@ import os import re import typing +import unicodedata from pathlib import Path from typing import Dict, Optional, Tuple @@ -576,8 +577,11 @@ def dictionary_setup(self) -> Tuple[typing.Set[str], collections.Counter]: word = clitic_cleanup_regex.sub(self.clitic_marker, word) if word in self.specials_set: continue - characters = list(word) if word not in special_words: + if getattr(self, "unicode_decomposition", False): + characters = unicodedata.normalize("NFKD", word) + else: + characters = word graphemes.update(characters) if pretrained: difference = set(pron) - self.non_silence_phones - self.silence_phones @@ -861,6 +865,8 @@ def calculate_disambiguation(self) -> None: dictionaries = session.query(Dictionary) update_pron_objs = [] for d in dictionaries: + if d.name == "default": + continue subsequences = set() words = ( session.query(Word) @@ -871,7 +877,7 @@ def calculate_disambiguation(self) -> None: for w in words: for p in w.pronunciations: pron = p.pronunciation.split() - while pron: + while len(pron) > 0: subsequences.add(tuple(pron)) pron = pron[:-1] last_used = collections.defaultdict(int) diff --git a/montreal_forced_aligner/g2p/generator.py b/montreal_forced_aligner/g2p/generator.py index eba43f1b..c89a47e0 100644 --- a/montreal_forced_aligner/g2p/generator.py +++ b/montreal_forced_aligner/g2p/generator.py @@ -11,6 +11,7 @@ import statistics import time import typing +import unicodedata from multiprocessing.pool import ThreadPool from pathlib import Path from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union 
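The hunk below threads a unicode_decomposition option through the G2P rewriters. As a minimal standalone illustration (not part of the patch) of why NFKD matters for grapheme-level models: decomposition splits precomposed characters into base letters plus combining marks, which changes the grapheme inventory the rewriter sees.

import unicodedata

word = "élève"  # hypothetical input word
composed = unicodedata.normalize("NFKC", word)
decomposed = unicodedata.normalize("NFKD", word)
print(len(composed), list(composed))      # 5 -- ['é', 'l', 'è', 'v', 'e']
print(len(decomposed), list(decomposed))  # 7 -- base letters plus U+0301/U+0300 combining accents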
@@ -176,11 +177,13 @@ def __init__( threshold: float = 1, graphemes: Set[str] = None, strict: bool = False, + unicode_decomposition: bool = False, ): self.graphemes = graphemes self.grapheme_symbol_table = grapheme_symbol_table self.phone_symbol_table = phone_symbol_table self.strict = strict + self.unicode_decomposition = unicode_decomposition if num_pronunciations > 0: self.rewrite = functools.partial( scored_top_rewrites, @@ -198,7 +201,7 @@ def __init__( output_token_type=self.phone_symbol_table, ) - def create_word_fst(self, word: str) -> pynini.Fst: + def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]: if self.graphemes is not None: if self.strict and any(x not in self.graphemes for x in word): return None @@ -208,6 +211,8 @@ def create_word_fst(self, word: str) -> pynini.Fst: def __call__(self, graphemes: str) -> List[str]: # pragma: no cover """Call the rewrite function""" + if self.unicode_decomposition: + graphemes = unicodedata.normalize("NFKD", graphemes) if " " in graphemes: words = graphemes.split() hypotheses = [] @@ -264,13 +269,20 @@ def __init__( sequence_separator: str = "|", graphemes: Set[str] = None, strict: bool = False, + unicode_decomposition: bool = False, ): super().__init__( - fst, grapheme_symbol_table, phone_symbol_table, num_pronunciations, threshold, strict + fst, + grapheme_symbol_table, + phone_symbol_table, + num_pronunciations, + threshold, + graphemes, + strict, + unicode_decomposition, ) self.sequence_separator = sequence_separator self.grapheme_order = grapheme_order - self.graphemes = graphemes def create_word_fst(self, word: str) -> typing.Optional[pynini.Fst]: if self.graphemes is not None: @@ -798,8 +810,8 @@ def compute_validation_errors( hyp_pron_count += len(hyp) gold_pron_count += len(gold_pronunciations) logger.debug( - f"Generated an average of {hyp_pron_count /len(hypothesis_values)} variants " - f"The gold set had an average of {gold_pron_count/len(hypothesis_values)} variants." + f"Generated an average of {hyp_pron_count / len(hypothesis_values)} variants " + f"The gold set had an average of {gold_pron_count / len(hypothesis_values)} variants." 
) with ThreadPool(config.NUM_JOBS) as pool: gen = pool.starmap(score_g2p, to_comp) diff --git a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py index c0c466ed..08ea6639 100644 --- a/montreal_forced_aligner/g2p/phonetisaurus_trainer.py +++ b/montreal_forced_aligner/g2p/phonetisaurus_trainer.py @@ -8,6 +8,7 @@ import subprocess import threading import time +import unicodedata from pathlib import Path from queue import Queue @@ -1382,8 +1383,8 @@ def compute_initial_ngrams(self) -> None: thirdparty_binary("ngramsymbols"), "--OOV_symbol=", "--epsilon_symbol=", - input_path, - input_symbols_path, + str(input_path), + str(input_symbols_path), ], encoding="utf8", stderr=log_file, @@ -1623,6 +1624,7 @@ def meta(self) -> MetaDict: "grapheme_order": self.grapheme_order, "phone_order": self.phone_order, "sequence_separator": self.sequence_separator, + "unicode_decomposition": self.unicode_decomposition, "evaluation": {}, "training": { "num_words": self.g2p_num_training_words, @@ -1740,7 +1742,8 @@ def initialize_training(self) -> None: .filter(Word2Job.training == True) # noqa ) for pronunciation, word in query: - word = list(word) + if self.unicode_decomposition: + word = unicodedata.normalize("NFKD", word) self.g2p_training_graphemes.update(word) self.g2p_training_phones.update(pronunciation.split()) @@ -1809,7 +1812,8 @@ def initialize_training(self) -> None: self.working_directory.joinpath("output.txt"), "w" ) as phone_f: for pronunciation, word in query: - word = list(word) + if self.unicode_decomposition: + word = unicodedata.normalize("NFKD", word) grapheme_count += len(word) self.g2p_training_graphemes.update(word) self.g2p_num_training_pronunciations += 1 diff --git a/montreal_forced_aligner/g2p/trainer.py b/montreal_forced_aligner/g2p/trainer.py index 5950f146..413db3f0 100644 --- a/montreal_forced_aligner/g2p/trainer.py +++ b/montreal_forced_aligner/g2p/trainer.py @@ -13,6 +13,7 @@ import threading import time import typing +import unicodedata from pathlib import Path from queue import Queue from typing import Any, List, NamedTuple, Set @@ -204,12 +205,14 @@ def __init__( validation_proportion: float = 0.1, num_pronunciations: int = 0, evaluation_mode: bool = False, + unicode_decomposition: bool = True, **kwargs, ): super().__init__(**kwargs) self.evaluation_mode = evaluation_mode self.validation_proportion = validation_proportion self.num_pronunciations = num_pronunciations + self.unicode_decomposition = unicode_decomposition self.g2p_training_dictionary = {} self.g2p_validation_dictionary = None self.g2p_training_graphemes = set() @@ -482,7 +485,7 @@ def _lexicon_covering(self, input_path=None, output_path=None) -> None: else: com.append("--token_type=utf8") com.extend([input_path, self.input_far_path]) - print(" ".join(map(str, com)), file=log_file) + log_file.write(f'{" ".join(map(str, com))}\n') subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file) com = [ thirdparty_binary("farcompilestrings"), @@ -492,12 +495,12 @@ def _lexicon_covering(self, input_path=None, output_path=None) -> None: output_path, self.output_far_path, ] - print(" ".join(map(str, com)), file=log_file) + log_file.write(f'{" ".join(map(str, com))}\n') subprocess.check_call(com, env=os.environ, stderr=log_file, stdout=log_file) ilabels = _get_far_labels(self.input_far_path) - print(ilabels, file=log_file) + log_file.write(f"{ilabels}\n") olabels = _get_far_labels(self.output_far_path) - print(olabels, file=log_file) + 
log_file.write(f"{olabels}\n") cg = pywrapfst.VectorFst() state = cg.add_state() cg.set_start(state) @@ -719,6 +722,7 @@ def meta(self) -> MetaDict: "train_date": str(datetime.now()), "phones": sorted(self.non_silence_phones), "graphemes": self.g2p_training_graphemes, + "unicode_decomposition": self.unicode_decomposition, "evaluation": {}, "training": { "num_words": len(self.g2p_training_dictionary), @@ -776,15 +780,18 @@ def initialize_training(self) -> None: for word, pronunciations in self.g2p_training_dictionary.items(): if re.match(r"\W", word) is not None: continue + if self.unicode_decomposition: + word = unicodedata.normalize("NFKD", word) self.g2p_training_graphemes.update(word) for p in pronunciations: self.g2p_training_phones.update(p.split()) - print(word, file=inf) - print(p, file=outf) + inf.write(f"{word}\n") + outf.write(f"{p}\n") logger.debug(f"Graphemes in training data: {sorted(self.g2p_training_graphemes)}") logger.debug(f"Phones in training data: {sorted(self.g2p_training_phones)}") if self.evaluation_mode: for word, pronunciations in self.g2p_validation_dictionary.items(): + word = unicodedata.normalize("NFKD", word) self.g2p_validation_graphemes.update(word) for p in pronunciations: self.g2p_validation_phones.update(p.split()) diff --git a/montreal_forced_aligner/models.py b/montreal_forced_aligner/models.py index 01014b3c..3e7b22ad 100644 --- a/montreal_forced_aligner/models.py +++ b/montreal_forced_aligner/models.py @@ -981,12 +981,18 @@ def rewriter(self): graphemes=self.meta["graphemes"], sequence_separator=self.meta["sequence_separator"], strict=True, + unicode_decomposition=self.meta["unicode_decomposition"], ) else: from montreal_forced_aligner.g2p.generator import Rewriter rewriter = Rewriter( - self.fst, self.grapheme_table, self.phone_table, num_pronunciations=1, strict=True + self.fst, + self.grapheme_table, + self.phone_table, + num_pronunciations=1, + strict=True, + unicode_decomposition=self.meta["unicode_decomposition"], ) return rewriter @@ -1024,6 +1030,7 @@ def meta(self) -> dict: self._meta["graphemes"] = set(self._meta.get("graphemes", [])) self._meta["evaluation"] = self._meta.get("evaluation", []) self._meta["training"] = self._meta.get("training", []) + self._meta["unicode_decomposition"] = self._meta.get("unicode_decomposition", False) return self._meta @property diff --git a/montreal_forced_aligner/tokenization/simple.py b/montreal_forced_aligner/tokenization/simple.py index cbb0bca5..5e2680a8 100644 --- a/montreal_forced_aligner/tokenization/simple.py +++ b/montreal_forced_aligner/tokenization/simple.py @@ -277,8 +277,7 @@ def parse_graphemes( yield word break else: - characters = list(item) - for c in characters: + for c in item: if self.grapheme_set is not None and c in self.grapheme_set: yield c else: diff --git a/montreal_forced_aligner/transcription/models.py b/montreal_forced_aligner/transcription/models.py index 471c63ec..4c9aaea9 100644 --- a/montreal_forced_aligner/transcription/models.py +++ b/montreal_forced_aligner/transcription/models.py @@ -7,6 +7,8 @@ import numpy as np +from montreal_forced_aligner.data import Language + try: with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -38,6 +40,8 @@ def __init__( suppress_numerals: bool = False, **kwargs, ): + self.preset_language = None + self.tokenizer = None super().__init__( model, vad, @@ -52,7 +56,17 @@ def __init__( ) self.base_suppress_tokens = self.options.suppress_tokens if self.preset_language is not None: - self.load_tokenizer(task="transcribe", 
language=self.preset_language) + self.load_tokenizer(language=self.preset_language) + + def set_language(self, language: typing.Union[str, Language]): + if isinstance(language, str): + language = Language[language] + language = language.value + if self.preset_language != language: + self.preset_language = language + self.tokenizer = None + if self.preset_language is not None: + self.load_tokenizer(language=self.preset_language) def get_suppressed_tokens( self, @@ -62,7 +76,10 @@ def get_suppressed_tokens( alpha_pattern = re.compile(r"\w", flags=re.UNICODE) roman_numeral_pattern = re.compile(r"^(x+(vi+|i+|i?v|x+))$", flags=re.IGNORECASE) - case_roman_numeral_pattern = re.compile(r"(^[IXV]{2,}$|^[xvi]+i$|^x{2,}$|\d)") + case_roman_numeral_pattern = re.compile(r"(^[IXV]{2,}$|^[xv]+i{2,}$|^x{2,}iv$|\d)") + abbreviations_pattern = re.compile( + r"^(sr|sra|mr|dr|mrs|vds|vd|etc)\.?$", flags=re.IGNORECASE + ) def _should_suppress(t): if t.startswith("<|"): @@ -72,6 +89,7 @@ def _should_suppress(t): if ( roman_numeral_pattern.search(t) or case_roman_numeral_pattern.search(t) + or abbreviations_pattern.match(t) or re.match(r"^[XV]$", t) or not alpha_pattern.search(t) ): @@ -86,9 +104,12 @@ def _should_suppress(t): suppressed.append(token_id) return suppressed - def load_tokenizer(self, task, language): + def load_tokenizer(self, language): self.tokenizer = faster_whisper.tokenizer.Tokenizer( - self.model.hf_tokenizer, self.model.model.is_multilingual, task=task, language=language + self.model.hf_tokenizer, + self.model.model.is_multilingual, + task="transcribe", + language=language, ) if self.suppress_numerals: numeral_symbol_tokens = self.get_suppressed_tokens() @@ -102,8 +123,17 @@ def transcribe( utterance_ids, batch_size=None, num_workers=0, - vad_segments=None, ): + if self.preset_language is None: + max_len = 0 + audio = None + for a in audio_batch: + if a["inputs"].shape[-1] > max_len: + max_len = a["inputs"].shape[-1] + audio = a["inputs"] + language = self.detect_language(audio) + self.load_tokenizer(language=language) + utterances = {} batch_size = batch_size or self._batch_size for idx, out in enumerate( @@ -122,7 +152,8 @@ def transcribe( "end": round(audio_batch[idx]["end"], 3), } ) - + if self.preset_language is None: + self.tokenizer = None return utterances diff --git a/montreal_forced_aligner/transcription/multiprocessing.py b/montreal_forced_aligner/transcription/multiprocessing.py index adc169da..2d49da10 100644 --- a/montreal_forced_aligner/transcription/multiprocessing.py +++ b/montreal_forced_aligner/transcription/multiprocessing.py @@ -59,8 +59,6 @@ "speechbrain.lobes.models.huggingface_transformers.huggingface" ) transformers_logger.setLevel(logging.ERROR) - transformers_logger = logging.getLogger("kenlm") - transformers_logger.setLevel(logging.ERROR) import torch try: @@ -653,7 +651,6 @@ def _run(self) -> None: ), run_opts=run_opts, ) - return_q = queue.Queue(2) finished_adding = threading.Event() stopped = threading.Event() diff --git a/montreal_forced_aligner/transcription/transcriber.py b/montreal_forced_aligner/transcription/transcriber.py index f5f43f34..bbc969b0 100644 --- a/montreal_forced_aligner/transcription/transcriber.py +++ b/montreal_forced_aligner/transcription/transcriber.py @@ -1611,17 +1611,43 @@ def setup(self) -> None: model_key = f"speechbrain/asr-{self.architecture}-commonvoice-14-{common_voice_code}" try: with warnings.catch_warnings(): + from speechbrain.utils.fetching import fetch + warnings.simplefilter("ignore") if self.architecture == "wav2vec2": 
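To make the model lookup above concrete: the language's ISO code, now stored directly as the Language enum value, is interpolated into a speechbrain repository name. A rough sketch assuming a French wav2vec2 configuration; whether the resulting repository exists depends on the language/architecture pair, and the Common Voice code is simply taken to be the ISO value here.

from montreal_forced_aligner.data import Language

language = Language["french"]       # name-based lookup, as the new parse_args path does
architecture = "wav2vec2"
common_voice_code = language.value  # "fr" -- iso_code now just returns the enum value
model_key = f"speechbrain/asr-{architecture}-commonvoice-14-{common_voice_code}"
print(model_key)                    # speechbrain/asr-wav2vec2-commonvoice-14-fr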
- # Download models if needed + hparam_path = os.path.join( + config.TEMPORARY_DIRECTORY, + "models", + "EncoderASR", + model_key, + "hyperparams.yaml", + ) + hf_cache_path = os.path.join(config.TEMPORARY_DIRECTORY, "models", "hf_cache") + if not os.path.exists(hparam_path): + hparams_local_path = fetch( + filename="hyperparams.yaml", + source=model_key, + savedir=os.path.join( + config.TEMPORARY_DIRECTORY, "models", "EncoderASR", model_key + ), + overwrite=False, + huggingface_cache_dir=hf_cache_path, + ) + with mfa_open(hparams_local_path, "r") as f: + data = f.read() + data = data.replace( + "save_path: wav2vec2_checkpoint", + f"save_path: {os.path.join(hf_cache_path, 'wav2vec2_checkpoint')}", + ) + data = data.replace("kenlm_model_path:", "# kenlm_model_path:") + with mfa_open(hparams_local_path, "w") as f: + f.write(data) m = EncoderASR.from_hparams( source=model_key, savedir=os.path.join( config.TEMPORARY_DIRECTORY, "models", "EncoderASR", model_key ), - huggingface_cache_dir=os.path.join( - config.TEMPORARY_DIRECTORY, "models", "hf_cache" - ), + huggingface_cache_dir=hf_cache_path, ) else: # Download models if needed @@ -1642,6 +1668,7 @@ def setup(self) -> None: except ImportError: raise except Exception: + raise raise ModelError( f"Could not download a speechbrain model with {self.architecture} and {self.language.name} ({model_key})" ) diff --git a/montreal_forced_aligner/utils.py b/montreal_forced_aligner/utils.py index 048719e1..7d975a6f 100644 --- a/montreal_forced_aligner/utils.py +++ b/montreal_forced_aligner/utils.py @@ -17,6 +17,7 @@ import threading import time import typing +import unicodedata from contextlib import contextmanager from pathlib import Path from typing import Any, Dict, List @@ -191,6 +192,7 @@ def parse_dictionary_file( f'Error parsing line {i} of {path}: "{line}" did not have a pronunciation' ) word = line.pop(0) + word = unicodedata.normalize("NFKC", word) prob = None silence_after_prob = None silence_before_correct = None diff --git a/montreal_forced_aligner/vad/models.py b/montreal_forced_aligner/vad/models.py index 25bee8b6..9dd0d885 100644 --- a/montreal_forced_aligner/vad/models.py +++ b/montreal_forced_aligner/vad/models.py @@ -12,6 +12,7 @@ from kalpy.data import Segment from montreal_forced_aligner import config +from montreal_forced_aligner.data import CtmInterval if typing.TYPE_CHECKING: from montreal_forced_aligner.abc import MetaDict @@ -40,11 +41,100 @@ logger = logging.getLogger("mfa") +def get_initial_segmentation(frames: np.ndarray, frame_shift: float) -> typing.List[CtmInterval]: + """ + Compute initial segmentation over voice activity + + Parameters + ---------- + frames: list[Union[int, str]] + List of frames with VAD output + frame_shift: float + Frame shift of features in seconds + + Returns + ------- + List[CtmInterval] + Initial segmentation + """ + segments = [] + cur_segment = None + silent_frames = 0 + non_silent_frames = 0 + for i in range(frames.shape[0]): + f = frames[i] + if int(f) > 0: + non_silent_frames += 1 + if cur_segment is None: + cur_segment = CtmInterval(begin=i * frame_shift, end=0, label="speech") + else: + silent_frames += 1 + if cur_segment is not None: + cur_segment.end = (i - 1) * frame_shift + segments.append(cur_segment) + cur_segment = None + if cur_segment is not None: + cur_segment.end = len(frames) * frame_shift + segments.append(cur_segment) + return segments + + +def merge_segments( + segments: typing.List[CtmInterval], + min_pause_duration: float, + max_segment_length: float, + min_segment_length: 
float, + snap_boundaries: bool = True, +) -> typing.List[CtmInterval]: + """ + Merge segments together + + Parameters + ---------- + segments: SegmentationType + Initial segments + min_pause_duration: float + Minimum amount of silence time to mark an utterance boundary + max_segment_length: float + Maximum length of segments before they're broken up + min_segment_length: float + Minimum length of segments returned + + Returns + ------- + List[CtmInterval] + Merged segments + """ + merged_segments = [] + snap_boundary_threshold = 0 + if snap_boundaries: + snap_boundary_threshold = min_pause_duration / 2 + for s in segments: + if ( + not merged_segments + or s.begin > merged_segments[-1].end + min_pause_duration + or s.end - merged_segments[-1].begin > max_segment_length + ): + if merged_segments and snap_boundary_threshold: + boundary_gap = s.begin - merged_segments[-1].end + if boundary_gap < snap_boundary_threshold: + half_boundary = boundary_gap / 2 + else: + half_boundary = snap_boundary_threshold / 2 + merged_segments[-1].end += half_boundary + s.begin -= half_boundary + + merged_segments.append(s) + else: + merged_segments[-1].end = s.end + return [x for x in merged_segments if x.end - x.begin > min_segment_length] + + class MfaVAD(VAD): def energy_VAD( self, audio_file: typing.Union[str, Path, np.ndarray, torch.Tensor], - boundaries, + segments, activation_th=0.5, deactivation_th=0.0, eps=1e-6, @@ -65,7 +155,7 @@ def energy_VAD( audio_file: path Path of the audio file containing the recording. The file is read with torchaudio. - boundaries: torch.Tensor + segments: list[CtmInterval] torch.Tensor containing the speech boundaries. It can be derived using the get_boundaries method. activation_th: float @@ -93,26 +183,26 @@ def energy_VAD( # Computing the chunk length of the energy window chunk_len = int(self.time_resolution * sample_rate) - new_boundaries = [] + new_segments = [] # Processing speech segments - for i in range(boundaries.shape[0]): - begin_sample = int(boundaries[i, 0] * sample_rate) - end_sample = int(boundaries[i, 1] * sample_rate) + for segment in segments: + begin_sample = int(segment.begin * sample_rate) + end_sample = int(segment.end * sample_rate) seg_len = end_sample - begin_sample if seg_len < chunk_len: continue if not isinstance(audio_file, torch.Tensor): # Reading the speech segment - segment, _ = torchaudio.load( + audio, _ = torchaudio.load( audio_file, frame_offset=begin_sample, num_frames=seg_len ) else: - segment = audio_file[:, begin_sample : begin_sample + seg_len] + audio = audio_file[:, begin_sample : begin_sample + seg_len] # Create chunks segment_chunks = self.create_chunks( - segment, chunk_size=chunk_len, chunk_stride=chunk_len + audio, chunk_size=chunk_len, chunk_stride=chunk_len ) # Energy computation within each chunk @@ -123,27 +213,19 @@ def energy_VAD( energy_chunks = ( (energy_chunks - energy_chunks.mean()) / (2 * energy_chunks.std()) ) + 0.5 - energy_chunks = energy_chunks.unsqueeze(0).unsqueeze(2) + energy_chunks = energy_chunks # Apply threshold based on the energy value - energy_vad = self.apply_threshold( - energy_chunks, - activation_th=activation_th, - deactivation_th=deactivation_th, + new_segments.extend( + self.generate_segments( + energy_chunks, + activation_th=activation_th, + deactivation_th=deactivation_th, + begin=segment.begin, + end=segment.end, + ) ) - - # Get the boundaries - energy_boundaries = self.get_boundaries(energy_vad, output_value="seconds") - - # Get the final boundaries in the original signal - for j in 
range(energy_boundaries.shape[0]): - start_en = boundaries[i, 0] + energy_boundaries[j, 0] - end_end = boundaries[i, 0] + energy_boundaries[j, 1] - new_boundaries.append([start_en, end_end]) - - # Convert boundaries to tensor - new_boundaries = torch.FloatTensor(new_boundaries).to(boundaries.device) - return new_boundaries + return new_segments def double_check_speech_segments(self, boundaries, audio_file, speech_th=0.5): """Takes in input the boundaries of the detected speech segments and @@ -204,15 +286,13 @@ def segment_utterance( self, segment: typing.Union[Segment, np.ndarray], apply_energy_VAD: bool = False, - double_check: bool = False, close_th: float = 0.333, len_th: float = 0.333, activation_th: float = 0.5, deactivation_th: float = 0.25, en_activation_th: float = 0.5, en_deactivation_th: float = 0.4, - speech_th: float = 0.5, - allow_empty: bool = True, + **kwargs, ) -> typing.List[Segment]: if isinstance(segment, Segment): y = torch.tensor(segment.wave[np.newaxis, :]) @@ -223,70 +303,159 @@ def segment_utterance( y = torch.tensor(segment) else: y = segment - prob_chunks = self.get_speech_prob_chunk(y).float() - prob_th = self.apply_threshold( + prob_chunks = self.get_speech_prob_chunk(y).float().cpu().numpy()[0, ...] + # Compute the boundaries of the speech segments + segments = self.generate_segments( prob_chunks, activation_th=activation_th, deactivation_th=deactivation_th, - ).float() + begin=segment.begin + if isinstance(segment, Segment) and segment.begin is not None + else None, + end=segment.end if isinstance(segment, Segment) and segment.end is not None else None, + ) - # Compute the boundaries of the speech segments - boundaries = self.get_boundaries(prob_th, output_value="seconds").cpu() - if isinstance(segment, Segment) and segment.begin is not None: - boundaries += segment.begin # Apply energy-based VAD on the detected speech segments if apply_energy_VAD: - vad_boundaries = self.energy_VAD( + segments = self.energy_VAD( y, - boundaries, + segments, activation_th=en_activation_th, deactivation_th=en_deactivation_th, ) - if vad_boundaries.size(0) != 0 or allow_empty: - boundaries = vad_boundaries # Merge short segments - boundaries = self.merge_close_segments(boundaries, close_th=close_th) - - # Remove short segments - filtered_boundaries = self.remove_short_segments(boundaries, len_th=len_th) - if filtered_boundaries.size(0) != 0 or allow_empty: - boundaries = filtered_boundaries + segments = merge_segments( + segments, + min_pause_duration=close_th, + max_segment_length=30, + min_segment_length=len_th, + snap_boundaries=False, + ) - # Double check speech segments - if double_check: - checked_boundaries = self.double_check_speech_segments( - boundaries, y, speech_th=speech_th - ) - if checked_boundaries.size(0) != 0 or allow_empty: - boundaries = checked_boundaries - boundaries[:, 0] -= round(close_th / 2, 3) - boundaries[:, 1] += round(close_th / 2, 3) - segments = [] - for i in range(boundaries.numpy().shape[0]): - begin, end = boundaries[i] + # Padding + for i, s in enumerate(segments): + begin, end = s.begin, s.end + begin -= close_th / 2 + end += close_th / 2 if i == 0: begin = max(begin, 0) - if i == boundaries.numpy().shape[0] - 1: + if i == len(segments) - 1: end = min( end, - segment.end - if isinstance(segment, Segment) - else segment.shape[0] / self.sample_rate, + segment.shape[0] / self.sample_rate + if not isinstance(segment, Segment) + else segment.end, ) - seg = Segment( - segment.file_path if isinstance(segment, Segment) else "", - float(begin), - 
float(end), - segment.channel if isinstance(segment, Segment) else 0, - ) - segments.append(seg) + s.begin = begin + s.end = end + if isinstance(segment, Segment): + segments[i] = Segment(segment.file_path, s.begin, s.end, segment.channel) return segments + def generate_segments( + self, vad_prob, activation_th=0.5, deactivation_th=0.25, begin=None, end=None + ): + """Scans the frame-level speech probabilities and applies a threshold + on them. Speech starts when a value larger than activation_th is + detected, while it ends when observing a value lower than + the deactivation_th. + + Arguments + --------- + vad_prob: numpy.ndarray + Frame-level speech probabilities. + activation_th: float + Threshold for starting a speech segment. + deactivation_th: float + Threshold for ending a speech segment. + + Returns + ------- + vad_th: torch.Tensor + torch.Tensor containing 1 for speech regions and 0 for non-speech regions. + """ + if begin is None: + begin = 0 + # Loop over batches and time steps + is_active = vad_prob[0] > activation_th + start = 0 + boundaries = [] + for time_step in range(1, vad_prob.shape[0] - 1): + y = vad_prob[time_step] + if is_active: + if y < deactivation_th: + e = self.time_resolution * (time_step - 1) + boundaries.append( + CtmInterval(begin=start + begin, end=e + begin, label="speech") + ) + is_active = False + elif y > activation_th: + is_active = True + start = self.time_resolution * time_step + if is_active: + if end is not None: + e = end + else: + e = self.time_resolution * vad_prob.shape[0] + e += begin + boundaries.append(CtmInterval(begin=start + begin, end=e, label="speech")) + return boundaries + + def get_speech_prob_chunk(self, wavs, wav_lens=None): + """Outputs the frame-level posterior probability for the input audio chunks + Outputs close to zero refers to time steps with a low probability of speech + activity, while outputs closer to one likely contain speech. + + Arguments + --------- + wavs : torch.Tensor + Batch of waveforms [batch, time, channels] or [batch, time] + depending on the model. Make sure the sample rate is fs=16000 Hz. + wav_lens : torch.Tensor + Lengths of the waveforms relative to the longest one in the + batch, tensor of shape [batch]. The longest one should have + relative length 1.0 and others len(waveform) / max_length. + Used for ignoring padding. 
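As a concrete illustration of the relative-length convention described above, a hedged usage sketch; the savedir path is arbitrary, and the pretrained source is the same one the segmenter mixin loads.

import torch
from montreal_forced_aligner.vad.models import MfaVAD

# Downloads the pretrained CRDNN VAD on first use.
vad = MfaVAD.from_hparams(source="speechbrain/vad-crdnn-libriparty", savedir="/tmp/mfa_vad")
wavs = torch.zeros(2, 48000)         # two 3-second waveforms at 16 kHz; the second is padded
wav_lens = torch.tensor([1.0, 0.5])  # the second waveform only contains 1.5 s of real audio
probs = vad.get_speech_prob_chunk(wavs, wav_lens)
print(probs.shape)                   # per-frame speech probabilities for each batch item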
+ + Returns + ------- + torch.Tensor + The encoded batch + """ + # Manage single waveforms in input + if len(wavs.shape) == 1: + wavs = wavs.unsqueeze(0) + + # Assign full length if wav_lens is not assigned + if wav_lens is None: + wav_lens = torch.ones(wavs.shape[0], device=self.device) + + # Storing waveform in the specified device + wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device) + wavs = wavs.float() + + # Computing features and embeddings + feats = self.mods.compute_features(wavs) + feats = self.mods.mean_var_norm(feats, wav_lens) + outputs = self.mods.cnn(feats) + + outputs = outputs.reshape( + outputs.shape[0], + outputs.shape[1], + outputs.shape[2] * outputs.shape[3], + ) + + outputs, h = self.mods.rnn(outputs) + outputs = self.mods.dnn(outputs) + output_prob = torch.sigmoid(outputs) + + return output_prob + def segment_for_whisper( self, segment: typing.Union[torch.Tensor, np.ndarray], - apply_energy_VAD: bool = False, + apply_energy_VAD: bool = True, close_th: float = 0.333, len_th: float = 0.333, activation_th: float = 0.5, @@ -295,60 +464,42 @@ def segment_for_whisper( en_deactivation_th: float = 0.4, **kwargs, ) -> typing.List[typing.Dict[str, float]]: - if len(segment.shape) == 1: - y = torch.tensor(segment[np.newaxis, :]) - elif not torch.is_tensor(segment): - y = torch.tensor(segment) + if isinstance(segment, Segment): + y = torch.tensor(segment.wave[np.newaxis, :]) else: - y = segment - prob_chunks = self.get_speech_prob_chunk(y).float() - prob_th = self.apply_threshold( - prob_chunks, + if len(segment.shape) == 1: + y = torch.tensor(segment[np.newaxis, :]) + elif not torch.is_tensor(segment): + y = torch.tensor(segment) + else: + y = segment + segments = self.segment_utterance( + segment, + apply_energy_VAD=apply_energy_VAD, + close_th=close_th, + len_th=len_th, activation_th=activation_th, deactivation_th=deactivation_th, - ).float() - - # Compute the boundaries of the speech segments - boundaries = self.get_boundaries(prob_th, output_value="seconds").cpu() - del prob_chunks - del prob_th + en_deactivation_th=en_deactivation_th, + **kwargs, + ) - # Apply energy-based VAD on the detected speech segments - if apply_energy_VAD: - vad_boundaries = self.energy_VAD( - y, - boundaries, - activation_th=en_activation_th, - deactivation_th=en_deactivation_th, + # Padding + segments_for_whisper = [] + for i, s in enumerate(segments): + begin, end = s.begin, s.end + f1 = int(round(begin, 3) * self.sample_rate) + f2 = int(round(end, 3) * self.sample_rate) + segments_for_whisper.append( + {"start": float(begin), "end": float(end), "inputs": y[0, f1:f2]} ) - boundaries = vad_boundaries - - # Merge short segments - boundaries = self.merge_close_segments(boundaries, close_th=close_th) - - # Remove short segments - filtered_boundaries = self.remove_short_segments(boundaries, len_th=len_th) - if filtered_boundaries.size(0) != 0: - boundaries = filtered_boundaries - boundaries[:, 0] -= round(close_th / 2, 3) - boundaries[:, 1] += round(close_th / 2, 3) - segments = [] - for i in range(boundaries.numpy().shape[0]): - begin, end = boundaries[i] - if i == 0: - begin = max(begin, 0) - if i == boundaries.numpy().shape[0] - 1: - end = min(end, segment.shape[0] / self.sample_rate) - f1 = int(float(begin) * self.sample_rate) - f2 = int(float(end) * self.sample_rate) - segments.append({"start": float(begin), "end": float(end), "inputs": y[0, f1:f2]}) - return segments + return segments_for_whisper class SpeechbrainSegmenterMixin: def __init__( self, - apply_energy_vad: bool = False, 
+ apply_energy_vad: bool = True, double_check: bool = False, close_th: float = 0.333, len_th: float = 0.333, diff --git a/montreal_forced_aligner/vad/multiprocessing.py b/montreal_forced_aligner/vad/multiprocessing.py index 1037c7da..0177e2f3 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -3,9 +3,8 @@ import typing from pathlib import Path -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, Union -import numpy import pynini import pywrapfst from _kalpy.decoder import LatticeFasterDecoder, LatticeFasterDecoderConfig @@ -24,11 +23,11 @@ from sqlalchemy.orm import joinedload, subqueryload from montreal_forced_aligner.abc import KaldiFunction -from montreal_forced_aligner.data import CtmInterval, MfaArguments +from montreal_forced_aligner.data import MfaArguments from montreal_forced_aligner.db import File, Job, Speaker, Utterance from montreal_forced_aligner.exceptions import SegmenterError from montreal_forced_aligner.models import AcousticModel, G2PModel -from montreal_forced_aligner.vad.models import MfaVAD +from montreal_forced_aligner.vad.models import MfaVAD, get_initial_segmentation, merge_segments if TYPE_CHECKING: SpeakerCharacterType = Union[str, int] @@ -43,8 +42,6 @@ "SegmentVadArguments", "SegmentTranscriptFunction", "SegmentVadFunction", - "get_initial_segmentation", - "merge_segments", "segment_utterance_transcript", "segment_utterance_vad", ] @@ -324,92 +321,6 @@ def align_interjection_words( return " ".join(original_transcript.split()[len(interjections_removed.split()) :]) -def get_initial_segmentation(frames: numpy.ndarray, frame_shift: float) -> List[CtmInterval]: - """ - Compute initial segmentation over voice activity - - Parameters - ---------- - frames: list[Union[int, str]] - List of frames with VAD output - frame_shift: float - Frame shift of features in seconds - - Returns - ------- - List[CtmInterval] - Initial segmentation - """ - segments = [] - cur_segment = None - silent_frames = 0 - non_silent_frames = 0 - for i in range(frames.shape[0]): - f = frames[i] - if int(f) > 0: - non_silent_frames += 1 - if cur_segment is None: - cur_segment = CtmInterval(begin=i * frame_shift, end=0, label="speech") - else: - silent_frames += 1 - if cur_segment is not None: - cur_segment.end = (i - 1) * frame_shift - segments.append(cur_segment) - cur_segment = None - if cur_segment is not None: - cur_segment.end = len(frames) * frame_shift - segments.append(cur_segment) - return segments - - -def merge_segments( - segments: List[CtmInterval], - min_pause_duration: float, - max_segment_length: float, - min_segment_length: float, -) -> List[CtmInterval]: - """ - Merge segments together - - Parameters - ---------- - segments: SegmentationType - Initial segments - min_pause_duration: float - Minimum amount of silence time to mark an utterance boundary - max_segment_length: float - Maximum length of segments before they're broken up - min_segment_length: float - Minimum length of segments returned - - Returns - ------- - List[CtmInterval] - Merged segments - """ - merged_segments = [] - snap_boundary_threshold = min_pause_duration / 2 - for s in segments: - if ( - not merged_segments - or s.begin > merged_segments[-1].end + min_pause_duration - or s.end - merged_segments[-1].begin > max_segment_length - ): - if merged_segments and snap_boundary_threshold: - boundary_gap = s.begin - merged_segments[-1].end - if boundary_gap < snap_boundary_threshold: - half_boundary = boundary_gap / 2 - 
else: - half_boundary = snap_boundary_threshold / 2 - merged_segments[-1].end += half_boundary - s.begin -= half_boundary - - merged_segments.append(s) - else: - merged_segments[-1].end = s.end - return [x for x in merged_segments if x.end - x.begin > min_segment_length] - - def segment_utterance_vad( segment: Segment, mfcc_options: MetaDict, diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index 27cb0f06..06ef9ff9 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -218,23 +218,15 @@ def segment_vad_speechbrain(self) -> None: .join(Utterance.file) ) for f, u in files: - boundaries = self.vad_model.get_speech_segments( - str(f.sound_file.sound_file_path), **kwargs - ).numpy() - for i in range(boundaries.shape[0]): + audio = f.sound_file.load_audio() + segments = self.vad_model.segment_utterance(audio, **kwargs) + for seg in segments: old_utts.add(u.id) - begin, end = boundaries[i, :] - begin -= round(self.close_th / 2, 3) - end += round(self.close_th / 2, 3) - if i == 0: - begin = max(0.0, begin) - if i == boundaries.shape[0] - 1: - end = min(f.sound_file.duration, end) new_utts.append( { "id": utt_index, - "begin": begin, - "end": end, + "begin": seg.begin, + "end": seg.end, "text": "speech", "speaker_id": u.speaker_id, "file_id": u.file_id, diff --git a/tests/test_commandline_align.py b/tests/test_commandline_align.py index b06bf7a8..5fb7b211 100644 --- a/tests/test_commandline_align.py +++ b/tests/test_commandline_align.py @@ -672,8 +672,6 @@ def test_swedish_cv( swedish_cv_dictionary, swedish_cv_acoustic_model, output_dir, - "--language", - "swedish", "--config_path", basic_align_config_path, "-q", @@ -713,18 +711,18 @@ def test_swedish_cv( def test_swedish_mfa( swedish_dir, generated_dir, - swedish_cv_dictionary, + swedish_mfa_dictionary, temp_dir, basic_align_config_path, - swedish_cv_acoustic_model, + swedish_mfa_acoustic_model, db_setup, ): output_dir = generated_dir.joinpath("swedish_mfa_output") command = [ "align", swedish_dir, - swedish_cv_dictionary, - swedish_cv_acoustic_model, + swedish_mfa_dictionary, + swedish_mfa_acoustic_model, output_dir, "--config_path", basic_align_config_path, diff --git a/tests/test_commandline_transcribe.py b/tests/test_commandline_transcribe.py index 88aecb06..744d012c 100644 --- a/tests/test_commandline_transcribe.py +++ b/tests/test_commandline_transcribe.py @@ -134,7 +134,6 @@ def test_transcribe_arpa( transcribe_config_path, db_setup, ): - temp_dir = os.path.join(temp_dir, "arpa_test_temp") output_path = generated_dir.joinpath("transcribe_test_arpa") command = [ "transcribe", diff --git a/tests/test_corpus.py b/tests/test_corpus.py index 118aa7fe..a5e3dd01 100644 --- a/tests/test_corpus.py +++ b/tests/test_corpus.py @@ -1,5 +1,6 @@ import os import shutil +import unicodedata from montreal_forced_aligner import config from montreal_forced_aligner.corpus.acoustic_corpus import ( @@ -596,8 +597,8 @@ def test_japanese(japanese_dir, japanese_dict_path, generated_dir, db_setup): print(corpus.utterances()) punctuated = corpus.get_utterances(file="日本語")[0] - assert punctuated.text == "「はい」、。! 『何 でしょう』" - assert punctuated.normalized_text == "はい 何 でしょう" + assert punctuated.text == unicodedata.normalize("NFKC", "「はい」、。! 
『何 でしょう』") + assert punctuated.normalized_text == unicodedata.normalize("NFKC", "はい 何 でしょう") corpus.cleanup_connections() @@ -613,8 +614,8 @@ def test_devanagari(devanagari_dir, hindi_dict_path, generated_dir, db_setup): print(corpus.utterances()) punctuated = corpus.get_utterances(file="devanagari")[0] - assert punctuated.text == "हैंः हूं हौंसला" - assert punctuated.normalized_text == "हैंः हूं हौंसला" + assert punctuated.text == unicodedata.normalize("NFKC", "हैंः हूं हौंसला") + assert punctuated.normalized_text == unicodedata.normalize("NFKC", "हैंः हूं हौंसला") corpus.cleanup_connections() @@ -630,12 +631,12 @@ def test_french_clitics(french_clitics_dir, frclitics_dict_path, generated_dir, corpus.load_corpus() punctuated = corpus.get_utterances(file="french_clitics")[0] - assert ( - punctuated.text - == "aujourd aujourd'hui m'appelle purple-people-eater vingt-six m'm'appelle c'est m'c'est m'appele m'ving-sic flying'purple-people-eater" + assert punctuated.text == unicodedata.normalize( + "NFKC", + "aujourd aujourd'hui m'appelle purple-people-eater vingt-six m'm'appelle c'est m'c'est m'appele m'ving-sic flying'purple-people-eater", ) - assert ( - punctuated.normalized_text - == "aujourd aujourd'hui m' appelle purple-people-eater vingt six m' m' appelle c'est m' c'est m' appele m' ving sic flying'purple-people-eater" + assert punctuated.normalized_text == unicodedata.normalize( + "NFKC", + "aujourd aujourd'hui m' appelle purple-people-eater vingt six m' m' appelle c'est m' c'est m' appele m' ving sic flying'purple-people-eater", ) corpus.cleanup_connections() From 4ff7cd985cb5bfebba427295d6eb02e76c18fadc Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Thu, 19 Sep 2024 19:01:58 -0700 Subject: [PATCH 06/16] Fix whisperx import error --- montreal_forced_aligner/transcription/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/montreal_forced_aligner/transcription/models.py b/montreal_forced_aligner/transcription/models.py index 4c9aaea9..53e6186d 100644 --- a/montreal_forced_aligner/transcription/models.py +++ b/montreal_forced_aligner/transcription/models.py @@ -18,7 +18,7 @@ FOUND_WHISPERX = True -except ImportError: +except (ImportError, OSError): FasterWhisperPipeline = object FOUND_WHISPERX = False From 88ea03391b1dca1fab1bfffc34f360a909f68809 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Fri, 20 Sep 2024 12:13:18 -0700 Subject: [PATCH 07/16] Fix for segment test failures --- montreal_forced_aligner/vad/models.py | 161 +++++---- .../vad/multiprocessing.py | 14 +- montreal_forced_aligner/vad/segmenter.py | 306 ++++++++++++++---- 3 files changed, 342 insertions(+), 139 deletions(-) diff --git a/montreal_forced_aligner/vad/models.py b/montreal_forced_aligner/vad/models.py index 9dd0d885..fca35473 100644 --- a/montreal_forced_aligner/vad/models.py +++ b/montreal_forced_aligner/vad/models.py @@ -135,8 +135,8 @@ def energy_VAD( self, audio_file: typing.Union[str, Path, np.ndarray, torch.Tensor], segments, - activation_th=0.5, - deactivation_th=0.0, + activation_threshold=0.5, + deactivation_threshold=0.0, eps=1e-6, ): """Applies energy-based VAD within the detected speech segments.The neural @@ -158,9 +158,9 @@ def energy_VAD( segments: list[CtmInterval] torch.Tensor containing the speech boundaries. It can be derived using the get_boundaries method. - activation_th: float + activation_threshold: float A new speech segment is started it the energy is above activation_th. 
- deactivation_th: float + deactivation_threshold: float The segment is considered ended when the energy is <= deactivation_th. eps: float Small constant for numerical stability. @@ -219,8 +219,8 @@ def energy_VAD( new_segments.extend( self.generate_segments( energy_chunks, - activation_th=activation_th, - deactivation_th=deactivation_th, + activation_threshold=activation_threshold, + deactivation_threshold=deactivation_threshold, begin=segment.begin, end=segment.end, ) @@ -285,13 +285,14 @@ def double_check_speech_segments(self, boundaries, audio_file, speech_th=0.5): def segment_utterance( self, segment: typing.Union[Segment, np.ndarray], - apply_energy_VAD: bool = False, - close_th: float = 0.333, - len_th: float = 0.333, - activation_th: float = 0.5, - deactivation_th: float = 0.25, - en_activation_th: float = 0.5, - en_deactivation_th: float = 0.4, + apply_energy_vad: bool = False, + min_pause_duration: float = 0.333, + max_segment_length: float = 30, + min_segment_length: float = 0.333, + activation_threshold: float = 0.5, + deactivation_threshold: float = 0.25, + energy_activation_threshold: float = 0.5, + energy_deactivation_threshold: float = 0.4, **kwargs, ) -> typing.List[Segment]: if isinstance(segment, Segment): @@ -307,8 +308,8 @@ def segment_utterance( # Compute the boundaries of the speech segments segments = self.generate_segments( prob_chunks, - activation_th=activation_th, - deactivation_th=deactivation_th, + activation_threshold=activation_threshold, + deactivation_threshold=deactivation_threshold, begin=segment.begin if isinstance(segment, Segment) and segment.begin is not None else None, @@ -316,28 +317,28 @@ def segment_utterance( ) # Apply energy-based VAD on the detected speech segments - if apply_energy_VAD: + if apply_energy_vad: segments = self.energy_VAD( y, segments, - activation_th=en_activation_th, - deactivation_th=en_deactivation_th, + activation_threshold=energy_activation_threshold, + deactivation_threshold=energy_deactivation_threshold, ) # Merge short segments segments = merge_segments( segments, - min_pause_duration=close_th, - max_segment_length=30, - min_segment_length=len_th, + min_pause_duration=min_pause_duration, + max_segment_length=max_segment_length, + min_segment_length=min_segment_length, snap_boundaries=False, ) # Padding for i, s in enumerate(segments): begin, end = s.begin, s.end - begin -= close_th / 2 - end += close_th / 2 + begin -= min_pause_duration / 2 + end += min_pause_duration / 2 if i == 0: begin = max(begin, 0) if i == len(segments) - 1: @@ -354,7 +355,7 @@ def segment_utterance( return segments def generate_segments( - self, vad_prob, activation_th=0.5, deactivation_th=0.25, begin=None, end=None + self, vad_prob, activation_threshold=0.5, deactivation_threshold=0.25, begin=None, end=None ): """Scans the frame-level speech probabilities and applies a threshold on them. Speech starts when a value larger than activation_th is @@ -365,9 +366,9 @@ def generate_segments( --------- vad_prob: numpy.ndarray Frame-level speech probabilities. - activation_th: float + activation_threshold: float Threshold for starting a speech segment. - deactivation_th: float + deactivation_threshold: float Threshold for ending a speech segment. 
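A toy trace of the thresholding above, assuming the speechbrain default time_resolution of 0.01 s: a segment opens at the first frame whose probability exceeds activation_threshold and closes at the first frame that falls below deactivation_threshold.

probs = [0.10, 0.60, 0.80, 0.70, 0.20, 0.10, 0.05]  # toy frame-level speech probabilities
activation_threshold, deactivation_threshold, time_resolution = 0.5, 0.25, 0.01
# frame 1 (0.60) rises above 0.5  -> a segment opens at 1 * 0.01 = 0.01 s
# frame 4 (0.20) falls below 0.25 -> the segment closes at (4 - 1) * 0.01 = 0.03 s
# so the method would return [CtmInterval(begin=0.01, end=0.03, label="speech")]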
Returns @@ -378,19 +379,19 @@ def generate_segments( if begin is None: begin = 0 # Loop over batches and time steps - is_active = vad_prob[0] > activation_th + is_active = vad_prob[0] > activation_threshold start = 0 boundaries = [] for time_step in range(1, vad_prob.shape[0] - 1): y = vad_prob[time_step] if is_active: - if y < deactivation_th: + if y < deactivation_threshold: e = self.time_resolution * (time_step - 1) boundaries.append( CtmInterval(begin=start + begin, end=e + begin, label="speech") ) is_active = False - elif y > activation_th: + elif y > activation_threshold: is_active = True start = self.time_resolution * time_step if is_active: @@ -455,13 +456,14 @@ def get_speech_prob_chunk(self, wavs, wav_lens=None): def segment_for_whisper( self, segment: typing.Union[torch.Tensor, np.ndarray], - apply_energy_VAD: bool = True, - close_th: float = 0.333, - len_th: float = 0.333, - activation_th: float = 0.5, - deactivation_th: float = 0.25, - en_activation_th: float = 0.5, - en_deactivation_th: float = 0.4, + apply_energy_vad: bool = True, + max_segment_length: float = 30, + min_segment_length: float = 0.333, + min_pause_duration: float = 0.333, + activation_threshold: float = 0.5, + deactivation_threshold: float = 0.25, + en_activation_threshold: float = 0.5, + en_deactivation_threshold: float = 0.4, **kwargs, ) -> typing.List[typing.Dict[str, float]]: if isinstance(segment, Segment): @@ -475,12 +477,14 @@ def segment_for_whisper( y = segment segments = self.segment_utterance( segment, - apply_energy_VAD=apply_energy_VAD, - close_th=close_th, - len_th=len_th, - activation_th=activation_th, - deactivation_th=deactivation_th, - en_deactivation_th=en_deactivation_th, + apply_energy_vad=apply_energy_vad, + max_segment_length=max_segment_length, + min_segment_length=min_segment_length, + min_pause_duration=min_pause_duration, + activation_threshold=activation_threshold, + deactivation_threshold=deactivation_threshold, + en_activation_threshold=en_activation_threshold, + en_deactivation_threshold=en_deactivation_threshold, **kwargs, ) @@ -496,18 +500,47 @@ def segment_for_whisper( return segments_for_whisper -class SpeechbrainSegmenterMixin: +class SegmenterMixin: + def __init__( + self, + max_segment_length: float = 30, + min_segment_length: float = 0.333, + min_pause_duration: float = 0.333, + activation_threshold: float = 0.5, + deactivation_threshold: float = 0.25, + energy_activation_threshold: float = 0.5, + energy_deactivation_threshold: float = 0.4, + **kwargs, + ): + self.max_segment_length = max_segment_length + self.min_segment_length = min_segment_length + self.min_pause_duration = min_pause_duration + self.activation_threshold = activation_threshold + self.deactivation_threshold = deactivation_threshold + self.energy_activation_threshold = energy_activation_threshold + self.energy_deactivation_threshold = energy_deactivation_threshold + super().__init__(**kwargs) + + @property + def segmentation_options(self) -> MetaDict: + """Options for segmentation""" + return { + "max_segment_length": self.max_segment_length, + "min_segment_length": self.min_segment_length, + "activation_threshold": self.activation_threshold, + "deactivation_threshold": self.deactivation_threshold, + "energy_activation_threshold": self.energy_activation_threshold, + "energy_deactivation_threshold": self.energy_deactivation_threshold, + "min_pause_duration": self.min_pause_duration, + } + + +class SpeechbrainSegmenterMixin(SegmenterMixin): def __init__( self, apply_energy_vad: bool = True, double_check: 
bool = False, - close_th: float = 0.333, - len_th: float = 0.333, - activation_th: float = 0.5, - deactivation_th: float = 0.25, - en_activation_th: float = 0.5, - en_deactivation_th: float = 0.4, - speech_th: float = 0.5, + speech_threshold: float = 0.5, cuda: bool = False, **kwargs, ): @@ -519,13 +552,7 @@ def __init__( super().__init__(**kwargs) self.apply_energy_vad = apply_energy_vad self.double_check = double_check - self.close_th = close_th - self.len_th = len_th - self.activation_th = activation_th - self.deactivation_th = deactivation_th - self.en_activation_th = en_activation_th - self.en_deactivation_th = en_deactivation_th - self.speech_th = speech_th + self.speech_threshold = speech_threshold self.cuda = cuda self.speechbrain = True self.vad_model = None @@ -541,14 +568,12 @@ def __init__( @property def segmentation_options(self) -> MetaDict: """Options for segmentation""" - return { - "apply_energy_VAD": self.apply_energy_vad, - "double_check": self.double_check, - "activation_th": self.activation_th, - "deactivation_th": self.deactivation_th, - "en_activation_th": self.en_activation_th, - "en_deactivation_th": self.en_deactivation_th, - "speech_th": self.speech_th, - "close_th": self.close_th, - "len_th": self.len_th, - } + options = super().segmentation_options + options.update( + { + "apply_energy_vad": self.apply_energy_vad, + "double_check": self.double_check, + "speech_threshold": self.speech_threshold, + } + ) + return options diff --git a/montreal_forced_aligner/vad/multiprocessing.py b/montreal_forced_aligner/vad/multiprocessing.py index 0177e2f3..5e45338e 100644 --- a/montreal_forced_aligner/vad/multiprocessing.py +++ b/montreal_forced_aligner/vad/multiprocessing.py @@ -70,7 +70,7 @@ class SegmentTranscriptArguments(MfaArguments): def segment_utterance( segment: Segment, - vad_model: MfaVAD, + vad_model: typing.Optional[MfaVAD], segmentation_options: MetaDict, mfcc_options: MetaDict = None, vad_options: MetaDict = None, @@ -344,9 +344,9 @@ def segment_utterance_vad( segments = get_initial_segmentation(vad, mfcc_computer.frame_shift) segments = merge_segments( segments, - segmentation_options["close_th"], - segmentation_options["large_chunk_size"], - segmentation_options["len_th"] if allow_empty else 0.02, + segmentation_options["min_pause_duration"], + segmentation_options["max_segment_length"], + segmentation_options["min_segment_length"] if allow_empty else 0.02, ) new_segments = [] for s in segments: @@ -397,9 +397,9 @@ def _run(self): merged = merge_segments( initial_segments, - self.segmentation_options["close_th"], - self.segmentation_options["large_chunk_size"], - self.segmentation_options["len_th"], + self.segmentation_options["min_pause_duration"], + self.segmentation_options["max_segment_length"], + self.segmentation_options["min_segment_length"], ) self.callback((int(utt_id.split("-")[-1]), merged)) reader.Next() diff --git a/montreal_forced_aligner/vad/segmenter.py b/montreal_forced_aligner/vad/segmenter.py index 06ef9ff9..3e2b7f1a 100644 --- a/montreal_forced_aligner/vad/segmenter.py +++ b/montreal_forced_aligner/vad/segmenter.py @@ -36,7 +36,7 @@ from montreal_forced_aligner.tokenization.spacy import generate_language_tokenizer from montreal_forced_aligner.transcription.transcriber import TranscriberMixin from montreal_forced_aligner.utils import log_kaldi_errors, run_kaldi_function -from montreal_forced_aligner.vad.models import SpeechbrainSegmenterMixin +from montreal_forced_aligner.vad.models import SegmenterMixin, SpeechbrainSegmenterMixin from 
montreal_forced_aligner.vad.multiprocessing import ( SegmentTranscriptArguments, SegmentTranscriptFunction, @@ -57,7 +57,7 @@ class VadSegmenter( VadConfigMixin, AcousticCorpusMixin, FileExporterMixin, - SpeechbrainSegmenterMixin, + SegmenterMixin, TopLevelMfaWorker, ): """ @@ -173,10 +173,10 @@ def segment_vad_arguments(self) -> List[SegmentVadArguments]: options = self.segmentation_options options.update( { - "large_chunk_size": self.large_chunk_size, + "large_chunk_size": self.max_segment_length, "frame_shift": getattr(self, "export_frame_shift", 0.01), - "small_chunk_size": self.small_chunk_size, - "overlap_small_chunk": self.overlap_small_chunk, + "small_chunk_size": self.min_segment_length, + "overlap_small_chunk": self.min_pause_duration, } ) return [ @@ -190,7 +190,7 @@ def segment_vad_arguments(self) -> List[SegmentVadArguments]: for j in self.jobs ] - def segment_vad_speechbrain(self) -> None: + def segment_vad(self) -> None: """ Run segmentation based off of VAD. @@ -202,51 +202,238 @@ def segment_vad_speechbrain(self) -> None: Job method for generating arguments for helper function """ + arguments = self.segment_vad_arguments() old_utts = set() new_utts = [] - kwargs = self.segmentation_options - with tqdm( - total=self.num_utterances, disable=config.QUIET - ) as pbar, self.session() as session: - utt_index = session.query(sqlalchemy.func.max(Utterance.id)).scalar() - if not utt_index: - utt_index = 0 - utt_index += 1 - files: List[File] = ( - session.query(File, Utterance) - .options(joinedload(File.sound_file)) - .join(Utterance.file) + + with self.session() as session: + utterances = session.query( + Utterance.id, Utterance.channel, Utterance.speaker_id, Utterance.file_id ) - for f, u in files: - audio = f.sound_file.load_audio() - segments = self.vad_model.segment_utterance(audio, **kwargs) + utterance_cache = {} + for u_id, channel, speaker_id, file_id in utterances: + utterance_cache[u_id] = (channel, speaker_id, file_id) + for utt, segments in run_kaldi_function( + SegmentVadFunction, arguments, total_count=self.num_utterances + ): + old_utts.add(utt) + channel, speaker_id, file_id = utterance_cache[utt] for seg in segments: - old_utts.add(u.id) new_utts.append( { - "id": utt_index, "begin": seg.begin, "end": seg.end, "text": "speech", - "speaker_id": u.speaker_id, - "file_id": u.file_id, + "speaker_id": speaker_id, + "file_id": file_id, "oovs": "", "normalized_text": "", "features": "", "in_subset": False, "ignored": False, - "channel": u.channel, + "channel": channel, } ) - utt_index += 1 - pbar.update(1) session.query(Utterance).filter(Utterance.id.in_(old_utts)).delete() session.bulk_insert_mappings( Utterance, new_utts, return_defaults=False, render_nulls=True ) session.commit() - def segment_vad_mfa(self) -> None: + def setup(self) -> None: + """Setup segmentation""" + super().setup() + self.create_new_current_workflow(WorkflowType.segmentation) + log_dir = self.working_directory.joinpath("log") + os.makedirs(log_dir, exist_ok=True) + try: + self.load_corpus() + except Exception as e: + if isinstance(e, KaldiProcessingError): + log_kaldi_errors(e.error_logs) + e.update_log_file() + raise + + def segment(self) -> None: + """ + Performs VAD and segmentation into utterances + + Raises + ------ + :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` + If there were any errors in running Kaldi binaries + """ + self.setup() + self.create_new_current_workflow(WorkflowType.segmentation) + wf = self.current_workflow + if wf.done: + logger.info("Segmentation 
already done, skipping.")
+            return
+        try:
+            self.compute_vad()
+            self.segment_vad()
+            with self.session() as session:
+                session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update(
+                    {"done": True}
+                )
+                session.commit()
+        except Exception as e:
+            with self.session() as session:
+                session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update(
+                    {"dirty": True}
+                )
+                session.commit()
+            if isinstance(e, KaldiProcessingError):
+                log_kaldi_errors(e.error_logs)
+                e.update_log_file()
+            raise
+
+    def export_files(self, output_directory: str, output_format: Optional[str] = None) -> None:
+        """
+        Export the results of segmentation as TextGrids
+
+        Parameters
+        ----------
+        output_directory: str
+            Directory to save segmentation TextGrids
+        output_format: str, optional
+            Format to force output files into
+        """
+        if output_format is None:
+            output_format = TextFileType.TEXTGRID.value
+        os.makedirs(output_directory, exist_ok=True)
+        with self.session() as session:
+            for f in session.query(File).options(
+                selectinload(File.utterances).joinedload(Utterance.speaker, innerjoin=True),
+                joinedload(File.sound_file, innerjoin=True),
+                joinedload(File.text_file),
+            ):
+                f.save(output_directory, output_format=output_format)
+
+    def segment_utterance(self, utterance_id: int, allow_empty: bool = True):
+        with self.session() as session:
+            utterance = full_load_utterance(session, utterance_id)
+
+        new_utterances = segment_utterance(
+            utterance.to_kalpy().segment,
+            None,
+            self.segmentation_options,
+            mfcc_options=self.mfcc_options,
+            vad_options=self.vad_options,
+            allow_empty=allow_empty,
+        )
+        return new_utterances
+
+
+class SpeechbrainVadSegmenter(
+    VadConfigMixin,
+    AcousticCorpusMixin,
+    FileExporterMixin,
+    SpeechbrainSegmenterMixin,
+    TopLevelMfaWorker,
+):
+    """
+    Class for performing VAD-based segmentation, parameters are passed to
+    `speechbrain.pretrained.interfaces.VAD.get_speech_segments
+    `_
+
+    Parameters
+    ----------
+    segment_padding: float
+        Size of padding on both ends of a segment
+    large_chunk_size: float
+        Size (in seconds) of the large chunks that are read sequentially
+        from the input audio file.
+    small_chunk_size: float
+        Size (in seconds) of the small chunks extracted from the large ones.
+        The audio signal is processed in parallel within the small chunks.
+        Note that large_chunk_size/small_chunk_size must be an integer.
+    overlap_small_chunk: bool
+        If True, it creates overlapped small chunks (with 50% overlap).
+        The probabilities of the overlapped chunks are combined using
+        hamming windows.
+    apply_energy_VAD: bool
+        If True, an energy-based VAD is used on the detected speech segments.
+        The neural network VAD often creates longer segments and tends to
+        merge close segments together. The energy VAD post-processing can be
+        useful for obtaining a fine-grained voice activity detection.
+        The energy thresholds are managed by activation_th and
+        deactivation_th (see below).
+    double_check: bool
+        If True, double checks (using the neural VAD) that the candidate
+        speech segments actually contain speech. A threshold on the mean
+        posterior probabilities provided by the neural network is applied
+        based on the speech_th parameter (see below).
+    activation_th: float
+        Threshold of the neural posteriors above which a speech segment is started.
+    deactivation_th: float
+        Threshold of the neural posteriors below which a speech segment is ended.
+    en_activation_th: float
+        A new speech segment is started if the energy is above activation_th.
+ This is active only if apply_energy_VAD is True. + en_deactivation_th: float + The segment is considered ended when the energy is <= deactivation_th. + This is active only if apply_energy_VAD is True. + speech_th: float + Threshold on the mean posterior probability within the candidate + speech segment. Below that threshold, the segment is re-assigned to + a non-speech region. This is active only if double_check is True. + close_th: float + If the distance between boundaries is smaller than close_th, the + segments will be merged. + len_th: float + If the length of the segment is smaller than len_th, the segments + will be merged. + """ + + def __init__( + self, + **kwargs, + ): + super().__init__(**kwargs) + self.transcriptions_required = False + + @classmethod + def parse_parameters( + cls, + config_path: Optional[Path] = None, + args: Optional[Dict[str, typing.Any]] = None, + unknown_args: Optional[typing.Iterable[str]] = None, + ) -> MetaDict: + """ + Parse parameters for segmentation from a config path or command-line arguments + + Parameters + ---------- + config_path: :class:`~pathlib.Path` + Config path + args: dict[str, Any] + Parsed arguments + unknown_args: list[str] + Optional list of arguments that were not parsed + + Returns + ------- + dict[str, Any] + Configuration parameters + """ + global_params = {} + if config_path and os.path.exists(config_path): + data = load_configuration(config_path) + for k, v in data.items(): + if k == "features": + if "type" in v: + v["feature_type"] = v["type"] + del v["type"] + global_params.update(v) + else: + if v is None and k in cls.nullable_fields: + v = [] + global_params[k] = v + global_params.update(cls.parse_args(args, unknown_args)) + return global_params + + def segment_vad(self) -> None: """ Run segmentation based off of VAD. 
@@ -258,38 +445,44 @@ def segment_vad_mfa(self) -> None: Job method for generating arguments for helper function """ - arguments = self.segment_vad_arguments() old_utts = set() new_utts = [] - - with self.session() as session: - utterances = session.query( - Utterance.id, Utterance.channel, Utterance.speaker_id, Utterance.file_id + kwargs = self.segmentation_options + with tqdm( + total=self.num_utterances, disable=config.QUIET + ) as pbar, self.session() as session: + utt_index = session.query(sqlalchemy.func.max(Utterance.id)).scalar() + if not utt_index: + utt_index = 0 + utt_index += 1 + files: List[File] = ( + session.query(File, Utterance) + .options(joinedload(File.sound_file)) + .join(Utterance.file) ) - utterance_cache = {} - for u_id, channel, speaker_id, file_id in utterances: - utterance_cache[u_id] = (channel, speaker_id, file_id) - for utt, segments in run_kaldi_function( - SegmentVadFunction, arguments, total_count=self.num_utterances - ): - old_utts.add(utt) - channel, speaker_id, file_id = utterance_cache[utt] + for f, u in files: + audio = f.sound_file.load_audio() + segments = self.vad_model.segment_utterance(audio, **kwargs) for seg in segments: + old_utts.add(u.id) new_utts.append( { + "id": utt_index, "begin": seg.begin, "end": seg.end, "text": "speech", - "speaker_id": speaker_id, - "file_id": file_id, + "speaker_id": u.speaker_id, + "file_id": u.file_id, "oovs": "", "normalized_text": "", "features": "", "in_subset": False, "ignored": False, - "channel": channel, + "channel": u.channel, } ) + utt_index += 1 + pbar.update(1) session.query(Utterance).filter(Utterance.id.in_(old_utts)).delete() session.bulk_insert_mappings( Utterance, new_utts, return_defaults=False, render_nulls=True @@ -303,11 +496,8 @@ def setup(self) -> None: log_dir = self.working_directory.joinpath("log") os.makedirs(log_dir, exist_ok=True) try: - if self.speechbrain: - self.initialize_database() - self._load_corpus() - else: - self.load_corpus() + self.initialize_database() + self._load_corpus() except Exception as e: if isinstance(e, KaldiProcessingError): log_kaldi_errors(e.error_logs) @@ -330,11 +520,7 @@ def segment(self) -> None: logger.info("Segmentation already done, skipping.") return try: - if not self.speechbrain: - self.compute_vad() - self.segment_vad_mfa() - else: - self.segment_vad_speechbrain() + self.segment_vad() with self.session() as session: session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update( {"done": True} @@ -379,10 +565,8 @@ def segment_utterance(self, utterance_id: int, allow_empty: bool = True): new_utterances = segment_utterance( utterance.to_kalpy().segment, - self.vad_model if self.speechbrain else None, + self.vad_model, self.segmentation_options, - mfcc_options=self.mfcc_options if not self.speechbrain else None, - vad_options=self.vad_options if not self.speechbrain else None, allow_empty=allow_empty, ) return new_utterances @@ -400,18 +584,12 @@ def __init__(self, acoustic_model_path: Path = None, **kwargs): def setup(self) -> None: TopLevelMfaWorker.setup(self) - self.create_new_current_workflow(WorkflowType.segmentation) self.setup_acoustic_model() - self.dictionary_setup() - self._load_corpus() - self.initialize_jobs() - self.normalize_text() - self.write_lexicon_information(write_disambiguation=False) def setup_acoustic_model(self): From eae38e0d8deb26e222e109217d6de2458086f528 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sat, 28 Sep 2024 12:40:07 -0700 Subject: [PATCH 08/16] Optimizations for training acoustic models --- 
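Notes on the training-subset changes below: utterance filtering is now parameterized. minimum_utterance_length (default 2) marks utterances with fewer words than that as ignored before training, and subset_word_count (default 3) feeds the regular expression used when selecting utterances for training subsets. A minimal sketch of how that pattern behaves, assuming plain Python re semantics (the patch applies the same pattern through PostgreSQL's ~ operator) and hypothetical example utterances:

import re

subset_word_count = 3  # default added to TrainableAligner in this patch
# An utterance qualifies for a training subset only if it contains at least
# subset_word_count words after the first one, i.e. subset_word_count + 1 words in total.
multiword_pattern = rf"(\s\S+){{{subset_word_count},}}"

for text in ["uh", "thank you", "the quick brown fox jumps"]:
    print(text, "->", bool(re.search(multiword_pattern, text)))
# uh -> False
# thank you -> False
# the quick brown fox jumps -> True
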
docs/source/changelog/changelog_3.0.rst | 48 ------------------- docs/source/changelog/changelog_3.1.rst | 48 +++++++++++++++++++ docs/source/changelog/changelog_3.2.rst | 16 +++++++ docs/source/changelog/index.md | 2 + .../acoustic_modeling/trainer.py | 12 ++++- .../alignment/multiprocessing.py | 7 +-- .../corpus/acoustic_corpus.py | 39 ++++++++++++++- montreal_forced_aligner/corpus/base.py | 12 +++-- 8 files changed, 127 insertions(+), 57 deletions(-) create mode 100644 docs/source/changelog/changelog_3.1.rst create mode 100644 docs/source/changelog/changelog_3.2.rst diff --git a/docs/source/changelog/changelog_3.0.rst b/docs/source/changelog/changelog_3.0.rst index 74903e4a..4eeb4dcd 100644 --- a/docs/source/changelog/changelog_3.0.rst +++ b/docs/source/changelog/changelog_3.0.rst @@ -5,54 +5,6 @@ 3.0 Changelog ************* -3.2.0 ------ - -- Add support for transcription via whisperx and speechbrain models -- Update text normalization to normalize to decomposed forms -- Compatibility with Kalpy 0.6.7 - -3.1.4 ------ - -- Optimized :code:`mfa g2p` to better use multiple processes -- Added :code:`--export_scores` to :code:`mfa g2p` for adding a column representing the final weights of the generated pronunciations -- Added :code:`--output_directory` to :code:`mfa validate` to save generated validation files rather than the temporary directory -- Fixed a bug in cutoff modeling that was preventing them from being properly parsed - -3.1.3 ------ - -- Fixed an issue where silence probability being zero was not correctly removing silence -- Compatibility with kalpy v0.6.5 -- Added API functionality for verifying transcripts with interjection words in alignment -- Fixed an error in fine tuning that generated nonsensical boundaries - -3.1.2 ------ - -- Fixed a bug where hidden files and folders would be parsed as corpus data -- Fixed a bug where validation would not respect :code:`--no_final_clean` -- Fixed a rare crash in training when a job would not have utterances assigned to it -- Fixed a bug where MFA would mistakenly report a dictionary and acoustic model phones did not match for older versions - -3.1.1 ------ - -- Fixed an issue with TextGrids missing intervals - -3.1.0 ------ - -- Fixed a bug where cutoffs were not properly modelled -- Added additional filter on create subset to not include utterances with cutoffs in smaller subsets -- Added the ability to specify HMM topologies for phones -- Fixed issues caused by validators not cleaning up temporary files and databases -- Added support for default and nonnative dictionaries generated from other dictionaries -- Restricted initial training rounds to exclude default and nonnative dictionaries -- Changed clustering of phones to not mix silence and non-silence phones -- Optimized textgrid export -- Added better memory management for collecting alignments 3.0.8 ----- diff --git a/docs/source/changelog/changelog_3.1.rst b/docs/source/changelog/changelog_3.1.rst new file mode 100644 index 00000000..67b5c7c4 --- /dev/null +++ b/docs/source/changelog/changelog_3.1.rst @@ -0,0 +1,48 @@ + +.. 
_changelog_3.1: + +************* +3.1 Changelog +************* + +3.1.4 +----- + +- Optimized :code:`mfa g2p` to better use multiple processes +- Added :code:`--export_scores` to :code:`mfa g2p` for adding a column representing the final weights of the generated pronunciations +- Added :code:`--output_directory` to :code:`mfa validate` to save generated validation files rather than the temporary directory +- Fixed a bug in cutoff modeling that was preventing them from being properly parsed + +3.1.3 +----- + +- Fixed an issue where silence probability being zero was not correctly removing silence +- Compatibility with kalpy v0.6.5 +- Added API functionality for verifying transcripts with interjection words in alignment +- Fixed an error in fine tuning that generated nonsensical boundaries + +3.1.2 +----- + +- Fixed a bug where hidden files and folders would be parsed as corpus data +- Fixed a bug where validation would not respect :code:`--no_final_clean` +- Fixed a rare crash in training when a job would not have utterances assigned to it +- Fixed a bug where MFA would mistakenly report a dictionary and acoustic model phones did not match for older versions + +3.1.1 +----- + +- Fixed an issue with TextGrids missing intervals + +3.1.0 +----- + +- Fixed a bug where cutoffs were not properly modelled +- Added additional filter on create subset to not include utterances with cutoffs in smaller subsets +- Added the ability to specify HMM topologies for phones +- Fixed issues caused by validators not cleaning up temporary files and databases +- Added support for default and nonnative dictionaries generated from other dictionaries +- Restricted initial training rounds to exclude default and nonnative dictionaries +- Changed clustering of phones to not mix silence and non-silence phones +- Optimized textgrid export +- Added better memory management for collecting alignments diff --git a/docs/source/changelog/changelog_3.2.rst b/docs/source/changelog/changelog_3.2.rst new file mode 100644 index 00000000..4e416ac2 --- /dev/null +++ b/docs/source/changelog/changelog_3.2.rst @@ -0,0 +1,16 @@ + +.. 
_changelog_3.2: + +************* +3.2 Changelog +************* + +3.2.0 +----- + +- Added :code:`--subset_word_count` parameter to :ref:`train_acoustic_model` to add a minimum word count for an utterance to be included in training subsets +- Added :code:`--minimum_utterance_length` parameter to :ref:`train_acoustic_model` to add a minimum word count for an utterance to be included in training at all +- Improved memory usage in compiling training graphs for initial subsets +- Add support for transcription via whisperx and speechbrain models +- Update text normalization to normalize to decomposed forms +- Compatibility with Kalpy 0.6.7 diff --git a/docs/source/changelog/index.md b/docs/source/changelog/index.md index 1dc3e6a0..1e0a6e2b 100644 --- a/docs/source/changelog/index.md +++ b/docs/source/changelog/index.md @@ -53,6 +53,8 @@ :hidden: :maxdepth: 1 +changelog_3.2.rst +changelog_3.1.rst news_3.0.rst changelog_3.0.rst changelog_2.2.rst diff --git a/montreal_forced_aligner/acoustic_modeling/trainer.py b/montreal_forced_aligner/acoustic_modeling/trainer.py index 745bea45..cb5bb611 100644 --- a/montreal_forced_aligner/acoustic_modeling/trainer.py +++ b/montreal_forced_aligner/acoustic_modeling/trainer.py @@ -153,6 +153,8 @@ def __init__( training_configuration: List[Tuple[str, Dict[str, Any]]] = None, phone_set_type: str = None, model_version: str = None, + subset_word_count: int = 3, + minimum_utterance_length: int = 2, **kwargs, ): self.param_dict = { @@ -164,6 +166,7 @@ def __init__( } self.final_identifier = None self.current_subset: int = 0 + self.subset_word_count = subset_word_count self.current_aligner: Optional[AcousticModelTrainingMixin] = None self.current_trainer: Optional[AcousticModelTrainingMixin] = None self.current_acoustic_model: Optional[AcousticModel] = None @@ -184,6 +187,7 @@ def __init__( self.final_alignment = True self.model_version = model_version self.boost_silence = 1.5 + self.minimum_utterance_length = minimum_utterance_length @classmethod def default_training_configurations(cls) -> List[Tuple[str, Dict[str, Any]]]: @@ -335,6 +339,12 @@ def filter_training_utterances(self): update_mapping.append({"id": u_id, "ignored": True}) continue words = text.split() + if ( + self.minimum_utterance_length > 1 + and len(words) < self.minimum_utterance_length + ): + update_mapping.append({"id": u_id, "ignored": True}) + continue if any(x in word_mapping for x in words): continue update_mapping.append({"id": u_id, "ignored": True}) @@ -629,7 +639,7 @@ def train(self) -> None: new_phone_lm_path = os.path.join(previous.working_directory, "phone_lm.fst") if not os.path.exists(new_phone_lm_path) and os.path.exists(phone_lm_path): shutil.copyfile(phone_lm_path, new_phone_lm_path) - logger.info(f"Completed training in {time.time()-begin} seconds!") + logger.info(f"Completed training in {time.time() - begin} seconds!") def transition_acc_arguments(self) -> List[TransitionAccArguments]: """ diff --git a/montreal_forced_aligner/alignment/multiprocessing.py b/montreal_forced_aligner/alignment/multiprocessing.py index 06775ad7..8b02a278 100644 --- a/montreal_forced_aligner/alignment/multiprocessing.py +++ b/montreal_forced_aligner/alignment/multiprocessing.py @@ -459,9 +459,9 @@ def _run(self): self.tree_path, lexicon, use_g2p=self.use_g2p, - batch_size=1000 + batch_size=500 if workflow.workflow_type is not WorkflowType.transcript_verification - else 500, + else 250, ) graph_logger.debug(f"Set up took {time.time() - begin} seconds") query = ( @@ -484,7 +484,7 @@ def _run(self): ) 
graph_logger.debug(f"Total compilation time: {time.time() - begin} seconds") del compiler - del self.lexicon_compilers + del lexicon class AccStatsFunction(KaldiFunction): @@ -1560,6 +1560,7 @@ def _run(self) -> None: pass alignment_archive.close() extraction_logger.debug("Finished ali first pass") + del lexicon_compiler extraction_logger.debug("Finished extraction") diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index a8ec6d9b..73822b11 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -39,7 +39,7 @@ AcousticDirectoryParser, CorpusProcessWorker, ) -from montreal_forced_aligner.data import DatabaseImportData, PhoneType, WorkflowType +from montreal_forced_aligner.data import DatabaseImportData, PhoneType, WordType, WorkflowType from montreal_forced_aligner.db import ( Corpus, CorpusWorkflow, @@ -50,6 +50,7 @@ Speaker, TextFile, Utterance, + Word, bulk_update, ) from montreal_forced_aligner.dictionary.mixins import DictionaryMixin @@ -1129,6 +1130,42 @@ def load_corpus(self) -> None: logger.debug(f"Setting up corpus took {time.time() - all_begin:.3f} seconds") + def subset_lexicon(self) -> None: + included_words = set() + with self.session() as session: + corpus = session.query(Corpus).first() + if corpus.current_subset > 0: + subset_utterances = ( + session.query(Utterance.normalized_text) + .filter(Utterance.in_subset == True) # noqa + .filter(Utterance.ignored == False) # noqa + ) + for (u_text,) in subset_utterances: + included_words.update(u_text.split()) + session.execute( + sqlalchemy.update(Word) + .where(Word.word_type == WordType.speech) + .values(included=False) + ) + session.flush() + session.execute( + sqlalchemy.update(Word) + .where(Word.word_type == WordType.speech) + .where(Word.count > self.oov_count_threshold) + .where(Word.word.in_(included_words)) + .values(included=True) + ) + else: + session.execute( + sqlalchemy.update(Word) + .where(Word.word_type == WordType.speech) + .where(Word.count > self.oov_count_threshold) + .values(included=True) + ) + + session.commit() + self.write_lexicon_information() + class AcousticCorpus(AcousticCorpusMixin, DictionaryMixin, MfaWorker): """ diff --git a/montreal_forced_aligner/corpus/base.py b/montreal_forced_aligner/corpus/base.py index 1bd996ab..ce80fb82 100644 --- a/montreal_forced_aligner/corpus/base.py +++ b/montreal_forced_aligner/corpus/base.py @@ -791,7 +791,7 @@ def normalize_text(self) -> None: import traceback exc_type, exc_value, exc_traceback = sys.exc_info() - print( + logger.debug( "\n".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) ) raise @@ -1200,7 +1200,8 @@ def create_subset(self, subset: int) -> None: cutoff_pattern = "<(cutoff|hes)" def add_filters(query): - multiword_pattern = r"\s\S+\s" + subset_word_count = getattr(self, "subset_word_count", 3) + multiword_pattern = rf"(\s\S+){{{subset_word_count},}}" filtered = ( query.filter( Utterance.normalized_text.op("~")(multiword_pattern) @@ -1488,7 +1489,7 @@ def add_filters(query): log_dir = subset_directory.joinpath("log") os.makedirs(log_dir, exist_ok=True) - logger.debug(f"Setting subset flags took {time.time()-begin} seconds") + logger.debug(f"Setting subset flags took {time.time() - begin} seconds") with self.session() as session: jobs = ( session.query(Job) @@ -1507,7 +1508,6 @@ def add_filters(query): ) for j in self._jobs ] - for _ in run_kaldi_function(ExportKaldiFilesFunction, arguments, 
total_count=subset): pass @@ -1559,10 +1559,14 @@ def subset_directory(self, subset: typing.Optional[int]) -> Path: c.current_subset = subset session.commit() if subset is None or subset >= self.num_utterances or subset <= 0: + if hasattr(self, "subset_lexicon"): + self.subset_lexicon() return self.split_directory directory = self.corpus_output_directory.joinpath(f"subset_{subset}") if not os.path.exists(directory): self.create_subset(subset) + if hasattr(self, "subset_lexicon"): + self.subset_lexicon() return directory def get_latest_workflow_run(self, workflow: WorkflowType, session: Session) -> CorpusWorkflow: From 74eb2bd91fd949d780145ffa17c23300bd81916d Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sat, 28 Sep 2024 22:47:54 -0700 Subject: [PATCH 09/16] Update environment --- environment.yml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/environment.yml b/environment.yml index c442fd11..9b35aad3 100644 --- a/environment.yml +++ b/environment.yml @@ -50,11 +50,6 @@ dependencies: - sudachipy - sudachidict-core - spacy-pkuseg - - pytorch - - torchaudio - # WhisperX dependencies - - cudnn =8 - - transformers - pip: - build - twine @@ -64,8 +59,3 @@ dependencies: - pythainlp - hanziconv - dragonmapper - # Speechbrain dependencies - - speechbrain - - pygtrie - # WhisperX dependencies - - whisperx From ac161c579c932b34f55b50e1a148321f3fd77695 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 29 Sep 2024 13:06:42 -0700 Subject: [PATCH 10/16] Update micromamba version for github actions --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 1c3a2a9a..fb168b78 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -37,7 +37,7 @@ jobs: fetch-depth: 0 - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1 + uses: mamba-org/setup-micromamba@v1.9.0 with: environment-file: environment.yml environment-name: mfa From 36d379e54e3bcdf0ec9043076d3640114e8267bb Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 29 Sep 2024 13:34:12 -0700 Subject: [PATCH 11/16] Update github actions environment --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index fb168b78..9a98af04 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,7 @@ jobs: include: - os: ubuntu-latest label: linux-64 - prefix: /usr/share/miniconda3/envs/my-env + prefix: /usr/share/miniconda3/envs/mfa #- os: macos-latest # label: osx-64 From 95fb7dcb582f71f885bb8f502d1089c8bc0056cc Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 29 Sep 2024 13:43:33 -0700 Subject: [PATCH 12/16] Switch to using miniforge for gha --- .github/workflows/main.yml | 10 +++------ github_environment.yml | 43 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 7 deletions(-) create mode 100644 github_environment.yml diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9a98af04..0acf3912 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,6 @@ jobs: include: - os: ubuntu-latest label: linux-64 - prefix: /usr/share/miniconda3/envs/mfa #- os: macos-latest # label: osx-64 @@ -37,13 +36,10 @@ jobs: fetch-depth: 0 - name: Install Conda environment with Micromamba - uses: mamba-org/setup-micromamba@v1.9.0 + uses: conda-incubator/setup-miniconda@v3 with: - 
environment-file: environment.yml - environment-name: mfa - create-args: >- - python=3.9 - cache-environment: true + environment-file: github_environment.yml + miniforge-version: latest - name: Configure mfa shell: bash -l {0} diff --git a/github_environment.yml b/github_environment.yml new file mode 100644 index 00000000..f42ef816 --- /dev/null +++ b/github_environment.yml @@ -0,0 +1,43 @@ +channels: + - conda-forge +dependencies: + - python>=3.8 + - numpy + - librosa + - pysoundfile + - tqdm + - requests + - pyyaml + - dataclassy + - kaldi=*=*cpu* + - scipy + - pynini + - openfst=1.8.3 + - scikit-learn<1.3 + - hdbscan + - baumwelch + - ngram + - praatio=6.0.0 + - biopython=1.79 + - sqlalchemy>=2.0 + - pgvector + - pgvector-python + - sqlite + - postgresql + - psycopg2 + - click + - setuptools_scm + - pytest + - pytest-mypy + - pytest-cov + - pytest-timeout + - mock + - coverage + - coveralls + - interrogate + - kneed + - matplotlib + - seaborn + - rich + - rich-click + - kalpy From 1921766c586baf803160f557c179053322854942 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 29 Sep 2024 13:46:25 -0700 Subject: [PATCH 13/16] Fix for torch import error --- montreal_forced_aligner/online/transcription.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/montreal_forced_aligner/online/transcription.py b/montreal_forced_aligner/online/transcription.py index cf07b198..30f96af9 100644 --- a/montreal_forced_aligner/online/transcription.py +++ b/montreal_forced_aligner/online/transcription.py @@ -4,7 +4,6 @@ import typing import numpy as np -import torch from _kalpy.fstext import ConstFst from _kalpy.matrix import DoubleMatrix, FloatMatrix from kalpy.data import Segment @@ -129,6 +128,8 @@ def transcribe_utterance_online_speechbrain( raise Exception( "Could not import speechbrain, please ensure it is installed via `pip install speechbrain`" ) + import torch + segment = utterance.segment waveform = segment.load_audio() waveform = model.audio_normalizer(waveform, 16000).unsqueeze(0) From 3bf76effb1a4dee6e2f9349dd8c08300609aa6f8 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Sun, 29 Sep 2024 14:27:55 -0700 Subject: [PATCH 14/16] Update segmentation test --- .../command_line/create_segments.py | 5 --- tests/test_commandline_create_segments.py | 37 ------------------- 2 files changed, 42 deletions(-) diff --git a/montreal_forced_aligner/command_line/create_segments.py b/montreal_forced_aligner/command_line/create_segments.py index 693caa08..af7ef7f2 100644 --- a/montreal_forced_aligner/command_line/create_segments.py +++ b/montreal_forced_aligner/command_line/create_segments.py @@ -119,11 +119,6 @@ def create_segments_vad_cli(context, **kwargs) -> None: default="long_textgrid", type=click.Choice(["long_textgrid", "short_textgrid", "json", "csv"]), ) -@click.option( - "--speechbrain/--no_speechbrain", - "speechbrain", - help="Flag for using SpeechBrain's pretrained VAD model", -) @click.option( "--cuda/--no_cuda", "cuda", diff --git a/tests/test_commandline_create_segments.py b/tests/test_commandline_create_segments.py index bad52c07..92bb14c2 100644 --- a/tests/test_commandline_create_segments.py +++ b/tests/test_commandline_create_segments.py @@ -84,42 +84,6 @@ def test_create_segments_transcripts( temp_dir, basic_segment_config_path, db_setup, -): - output_path = generated_dir.joinpath("segment_output") - command = [ - "segment", - basic_corpus_dir, - english_us_mfa_reduced_dict, - english_mfa_acoustic_model, - output_path, - "-q", - "--clean", - "--no_debug", - "-v", 
- "--config_path", - basic_segment_config_path, - ] - command = [str(x) for x in command] - result = click.testing.CliRunner(mix_stderr=False).invoke( - mfa_cli, command, catch_exceptions=True - ) - print(result.stdout) - print(result.stderr) - if result.exception: - print(result.exc_info) - raise result.exception - assert not result.return_value - assert os.path.exists(os.path.join(output_path, "michael", "acoustic_corpus.TextGrid")) - - -def test_create_segments_transcripts_speechbrain( - basic_corpus_dir, - english_mfa_acoustic_model, - english_us_mfa_reduced_dict, - generated_dir, - temp_dir, - basic_segment_config_path, - db_setup, ): if not FOUND_SPEECHBRAIN: pytest.skip("SpeechBrain not installed") @@ -134,7 +98,6 @@ def test_create_segments_transcripts_speechbrain( "--clean", "--no_debug", "-v", - "--speechbrain", "--no_use_mp", "--config_path", basic_segment_config_path, From 2eaaa83857332728469a85bdb2e4c02792d3479c Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Mon, 30 Sep 2024 13:42:25 -0700 Subject: [PATCH 15/16] Fix bug in transcript testing --- .../corpus/acoustic_corpus.py | 17 ++++++++--------- .../validation/corpus_validator.py | 19 ++++++++++--------- tests/test_acoustic_modeling.py | 4 ++++ 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/montreal_forced_aligner/corpus/acoustic_corpus.py b/montreal_forced_aligner/corpus/acoustic_corpus.py index 73822b11..1699660d 100644 --- a/montreal_forced_aligner/corpus/acoustic_corpus.py +++ b/montreal_forced_aligner/corpus/acoustic_corpus.py @@ -1130,10 +1130,16 @@ def load_corpus(self) -> None: logger.debug(f"Setting up corpus took {time.time() - all_begin:.3f} seconds") - def subset_lexicon(self) -> None: + def subset_lexicon(self, write_disambiguation: Optional[bool] = False) -> None: included_words = set() with self.session() as session: corpus = session.query(Corpus).first() + session.execute( + sqlalchemy.update(Word) + .where(Word.word_type == WordType.speech) + .values(included=False) + ) + session.flush() if corpus.current_subset > 0: subset_utterances = ( session.query(Utterance.normalized_text) @@ -1142,12 +1148,6 @@ def subset_lexicon(self) -> None: ) for (u_text,) in subset_utterances: included_words.update(u_text.split()) - session.execute( - sqlalchemy.update(Word) - .where(Word.word_type == WordType.speech) - .values(included=False) - ) - session.flush() session.execute( sqlalchemy.update(Word) .where(Word.word_type == WordType.speech) @@ -1162,9 +1162,8 @@ def subset_lexicon(self) -> None: .where(Word.count > self.oov_count_threshold) .values(included=True) ) - session.commit() - self.write_lexicon_information() + self.write_lexicon_information(write_disambiguation=write_disambiguation) class AcousticCorpus(AcousticCorpusMixin, DictionaryMixin, MfaWorker): diff --git a/montreal_forced_aligner/validation/corpus_validator.py b/montreal_forced_aligner/validation/corpus_validator.py index fc6bf8d3..def044e3 100644 --- a/montreal_forced_aligner/validation/corpus_validator.py +++ b/montreal_forced_aligner/validation/corpus_validator.py @@ -371,6 +371,7 @@ def test_utterance_transcriptions(self, output_directory: Path = None) -> None: output_directory = self.output_directory os.makedirs(output_directory, exist_ok=True) try: + self.subset_lexicon(write_disambiguation=True) self.train_speaker_lms() self.transcribe(WorkflowType.per_speaker_transcription) @@ -378,25 +379,25 @@ def test_utterance_transcriptions(self, output_directory: Path = None) -> None: logger.info("Test transcriptions") ser, wer, cer 
= self.compute_wer() if ser < 0.3: - logger.info(f"{ser*100:.2f}% sentence error rate") + logger.info(f"{ser * 100:.2f}% sentence error rate") elif ser < 0.8: - logger.warning(f"{ser*100:.2f}% sentence error rate") + logger.warning(f"{ser * 100:.2f}% sentence error rate") else: - logger.error(f"{ser*100:.2f}% sentence error rate") + logger.error(f"{ser * 100:.2f}% sentence error rate") if wer < 0.25: - logger.info(f"{wer*100:.2f}% word error rate") + logger.info(f"{wer * 100:.2f}% word error rate") elif wer < 0.75: - logger.warning(f"{wer*100:.2f}% word error rate") + logger.warning(f"{wer * 100:.2f}% word error rate") else: - logger.error(f"{wer*100:.2f}% word error rate") + logger.error(f"{wer * 100:.2f}% word error rate") if cer < 0.25: - logger.info(f"{cer*100:.2f}% character error rate") + logger.info(f"{cer * 100:.2f}% character error rate") elif cer < 0.75: - logger.warning(f"{cer*100:.2f}% character error rate") + logger.warning(f"{cer * 100:.2f}% character error rate") else: - logger.error(f"{cer*100:.2f}% character error rate") + logger.error(f"{cer * 100:.2f}% character error rate") self.save_transcription_evaluation(output_directory) out_path = os.path.join(output_directory, "transcription_evaluation.csv") diff --git a/tests/test_acoustic_modeling.py b/tests/test_acoustic_modeling.py index 8bc7dfa9..de789afc 100644 --- a/tests/test_acoustic_modeling.py +++ b/tests/test_acoustic_modeling.py @@ -73,6 +73,7 @@ def test_trainer(basic_dict_path, temp_dir, basic_corpus_dir): a.cleanup() +@pytest.mark.skip def test_basic_mono( mixed_dict_path, basic_corpus_dir, @@ -108,6 +109,7 @@ def test_basic_mono( a.clean_working_directory() +@pytest.mark.skip def test_pronunciation_training( mixed_dict_path, basic_corpus_dir, @@ -155,6 +157,7 @@ def test_pronunciation_training( a.clean_working_directory() +@pytest.mark.skip def test_pitch_feature_training( basic_dict_path, basic_corpus_dir, pitch_train_config_path, db_setup ): @@ -172,6 +175,7 @@ def test_pitch_feature_training( a.clean_working_directory() +@pytest.mark.skip def test_basic_lda(basic_dict_path, basic_corpus_dir, lda_train_config_path, db_setup): a = TrainableAligner( corpus_directory=basic_corpus_dir, From 2f4802e05230f22cd60ca4a19bc1b318be7b7922 Mon Sep 17 00:00:00 2001 From: Michael McAuliffe Date: Mon, 30 Sep 2024 14:34:35 -0700 Subject: [PATCH 16/16] Fixing adaptation for older models --- .../acoustic_modeling/base.py | 52 +++++++---------- montreal_forced_aligner/alignment/adapting.py | 58 +++++++------------ montreal_forced_aligner/ivector/trainer.py | 7 +-- 3 files changed, 47 insertions(+), 70 deletions(-) diff --git a/montreal_forced_aligner/acoustic_modeling/base.py b/montreal_forced_aligner/acoustic_modeling/base.py index 384daa69..0b16d722 100644 --- a/montreal_forced_aligner/acoustic_modeling/base.py +++ b/montreal_forced_aligner/acoustic_modeling/base.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -import os import time from abc import abstractmethod from pathlib import Path @@ -131,9 +130,7 @@ def acc_stats_arguments(self) -> List[AccStatsArguments]: AccStatsArguments( j.id, self.session if config.USE_THREADING else self.db_string, - os.path.join( - self.working_directory, "log", f"acc.{self.iteration}.{j.id}.log" - ), + self.working_log_directory.joinpath(f"acc.{self.iteration}.{j.id}.log"), self.working_directory, self.model_path, ) @@ -213,7 +210,7 @@ def initialize_training(self) -> None: ) self.subset = 0 self.worker.current_subset = 0 - os.makedirs(self.working_log_directory, 
exist_ok=True) + self.working_log_directory.mkdir(parents=True, exist_ok=True) self._trainer_initialization() self.iteration = 1 self.worker.current_trainer = self @@ -321,23 +318,23 @@ def acc_stats(self) -> None: train_logger.debug(f"Power: {self.power}") objf_impr, count = transition_model.mle_update(transition_accs) train_logger.debug( - f"Transition model update: Overall {objf_impr/count} " + f"Transition model update: Overall {objf_impr / count} " f"log-like improvement per frame over {count} frames." ) objf_impr, count = acoustic_model.mle_update( gmm_accs, mixup=self.current_gaussians, power=self.power ) train_logger.debug( - f"GMM update: Overall {objf_impr/count} " + f"GMM update: Overall {objf_impr / count} " f"objective function improvement per frame over {count} frames." ) tot_like = gmm_accs.TotLogLike() tot_t = gmm_accs.TotCount() train_logger.debug( - f"Average Likelihood per frame for iteration {self.iteration} = {tot_like/tot_t} " + f"Average Likelihood per frame for iteration {self.iteration} = {tot_like / tot_t} " f"over {tot_t} frames." ) - logger.debug(f"Log likelihood for iteration {self.iteration}: {tot_like/tot_t}") + logger.debug(f"Log likelihood for iteration {self.iteration}: {tot_like / tot_t}") write_gmm_model(str(self.next_model_path), transition_model, acoustic_model) def align_iteration(self) -> None: @@ -345,20 +342,20 @@ def align_iteration(self) -> None: begin = time.time() self.align_utterances(training=True) logger.debug( - f"Generating alignments for iteration {self.iteration} took {time.time()-begin} seconds" + f"Generating alignments for iteration {self.iteration} took {time.time() - begin} seconds" ) @property def initialized(self) -> bool: return ( - os.path.exists(self.working_directory.joinpath("1.mdl")) - or os.path.exists(self.working_directory.joinpath("final.mdl")) - or os.path.exists(self.working_directory.joinpath("done")) + self.working_directory.joinpath("1.mdl").exists() + or self.working_directory.joinpath("final.mdl").exists() + or self.working_directory.joinpath("done").exists() ) def train_iteration(self) -> None: """Perform an iteration of training""" - if os.path.exists(self.next_model_path): + if self.next_model_path.exists(): self.iteration += 1 if self.iteration <= self.final_gaussian_iteration: self.increment_gaussians() @@ -381,7 +378,7 @@ def train(self) -> None: :class:`~montreal_forced_aligner.exceptions.KaldiProcessingError` If there were any errors in running Kaldi binaries """ - os.makedirs(self.working_log_directory, exist_ok=True) + self.working_log_directory.mkdir(parents=True, exist_ok=True) wf = self.worker.current_workflow if wf.done: return @@ -419,27 +416,23 @@ def finalize_training(self) -> None: the model to be used in the next round alignment """ - os.rename( - self.working_directory.joinpath(f"{self.num_iterations+1}.mdl"), - self.working_directory.joinpath("final.mdl"), + self.working_directory.joinpath(f"{self.num_iterations + 1}.mdl").rename( + self.working_directory.joinpath("final.mdl") ) - ali_model_path = self.working_directory.joinpath(f"{self.num_iterations+1}.alimdl") - if os.path.exists(ali_model_path): - os.rename( - ali_model_path, - self.working_directory.joinpath("final.alimdl"), - ) + ali_model_path = self.working_directory.joinpath(f"{self.num_iterations + 1}.alimdl") + if ali_model_path.exists(): + ali_model_path.rename(self.working_directory.joinpath("final.alimdl")) self.export_model(self.exported_model_path) if not config.DEBUG: for i in range(1, self.num_iterations + 1): model_path = 
self.working_directory.joinpath(f"{i}.mdl") try: - os.remove(model_path) + model_path.unlink(missing_ok=True) except FileNotFoundError: pass - for file in os.listdir(self.working_directory): - if any(file.startswith(x) for x in ["fsts.", "trans.", "ali."]): - os.remove(self.working_directory.joinpath(file)) + for file in self.working_directory.iterdir(): + if any(file.name.startswith(x) for x in ["fsts.", "trans.", "ali."]): + file.unlink(missing_ok=True) wf = self.worker.current_workflow with self.session() as session: session.query(CorpusWorkflow).filter(CorpusWorkflow.id == wf.id).update({"done": True}) @@ -574,6 +567,5 @@ def export_model(self, output_model_path: Path) -> None: self.working_directory, self.worker.dictionary_base_names.values() ) if directory: - os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(output_model_path) + directory.mkdir(parents=True, exist_ok=True) acoustic_model.dump(output_model_path) diff --git a/montreal_forced_aligner/alignment/adapting.py b/montreal_forced_aligner/alignment/adapting.py index 3dab2f3b..b365d94f 100644 --- a/montreal_forced_aligner/alignment/adapting.py +++ b/montreal_forced_aligner/alignment/adapting.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging -import os import shutil import time from pathlib import Path @@ -118,20 +117,20 @@ def acc_stats(self, alignment: bool = False) -> None: IsmoothStatsAmDiagGmmFromModel(acoustic_model, self.mapping_tau, gmm_accs) objf_impr, count = transition_model.mle_update(transition_accs) logger.debug( - f"Transition model update: Overall {objf_impr/count} " + f"Transition model update: Overall {objf_impr / count} " f"log-like improvement per frame over {count} frames." ) objf_impr, count = acoustic_model.mle_update( gmm_accs, update_flags_str="m", remove_low_count_gaussians=False ) logger.debug( - f"GMM update: Overall {objf_impr/count} " + f"GMM update: Overall {objf_impr / count} " f"objective function improvement per frame over {count} frames." ) tot_like = gmm_accs.TotLogLike() tot_t = gmm_accs.TotCount() logger.debug( - f"Average Likelihood per frame = {tot_like/tot_t} " f"over {tot_t} frames." + f"Average Likelihood per frame = {tot_like / tot_t} " f"over {tot_t} frames." 
) write_gmm_model(str(final_mdl_path), transition_model, acoustic_model) @@ -157,7 +156,7 @@ def alignment_model_path(self) -> Path: """Current acoustic model path""" if self.current_workflow.workflow_type == WorkflowType.acoustic_model_adaptation: path = self.working_directory.joinpath("unadapted.alimdl") - if os.path.exists(path) and not getattr(self, "uses_speaker_adaptation", False): + if path.exists() and not getattr(self, "uses_speaker_adaptation", False): return path return self.model_path return super().alignment_model_path @@ -190,7 +189,7 @@ def train_map(self) -> None: """ begin = time.time() log_directory = self.working_log_directory - os.makedirs(log_directory, exist_ok=True) + log_directory.mkdir(parents=True, exist_ok=True) self.acc_stats(alignment=False) if self.uses_speaker_adaptation: @@ -204,19 +203,17 @@ def adapt(self) -> None: self.align() alignment_workflow = self.current_workflow self.create_new_current_workflow(WorkflowType.acoustic_model_adaptation) - for f in ["final.mdl", "final.alimdl"]: + for f in ["final.mdl", "final.alimdl", "tree", "lda.mat"]: + path = alignment_workflow.working_directory.joinpath(f) + new_path = self.working_directory.joinpath(f) + if f.startswith("final"): + new_path = new_path.with_stem("unadapted") + if not path.exists(): + continue shutil.copyfile( - os.path.join(alignment_workflow.working_directory, f), - self.working_directory.joinpath(f).with_stem("unadapted"), + path, + new_path, ) - shutil.copyfile( - os.path.join(alignment_workflow.working_directory, "tree"), - self.working_directory.joinpath("tree"), - ) - shutil.copyfile( - os.path.join(alignment_workflow.working_directory, "lda.mat"), - self.working_directory.joinpath("lda.mat"), - ) for j in self.jobs: old_paths = j.construct_path_dictionary( alignment_workflow.working_directory, "ali", "ark" @@ -224,28 +221,18 @@ def adapt(self) -> None: new_paths = j.construct_path_dictionary(self.working_directory, "ali", "ark") for k, v in old_paths.items(): shutil.copyfile(v, new_paths[k]) - os.makedirs(self.align_directory, exist_ok=True) + self.align_directory.mkdir(parents=True, exist_ok=True) try: logger.info("Adapting pretrained model...") self.train_map() self.export_model(self.working_log_directory.joinpath("acoustic_model.zip")) - shutil.copyfile( - self.working_directory.joinpath("final.mdl"), - os.path.join(self.align_directory, "final.mdl"), - ) - shutil.copyfile( - self.working_directory.joinpath("tree"), - os.path.join(self.align_directory, "tree"), - ) - if os.path.exists(self.working_directory.joinpath("final.alimdl")): - shutil.copyfile( - self.working_directory.joinpath("final.alimdl"), - os.path.join(self.align_directory, "final.alimdl"), - ) - if os.path.exists(self.working_directory.joinpath("lda.mat")): + for f in ["final.mdl", "final.alimdl", "tree", "lda.mat"]: + path = self.working_directory.joinpath(f) + if not path.exists(): + continue shutil.copyfile( - self.working_directory.joinpath("lda.mat"), - os.path.join(self.align_directory, "lda.mat"), + path, + self.align_directory.joinpath(f), ) wf = self.current_workflow with self.session() as session: @@ -317,6 +304,5 @@ def export_model(self, output_model_path: Path) -> None: acoustic_model.add_model(self.working_directory) acoustic_model.add_model(self.phones_dir) if directory: - os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(output_model_path) + directory.mkdir(parents=True, exist_ok=True) acoustic_model.dump(output_model_path) diff --git a/montreal_forced_aligner/ivector/trainer.py 
b/montreal_forced_aligner/ivector/trainer.py index 231b3ab2..ca972ca8 100644 --- a/montreal_forced_aligner/ivector/trainer.py +++ b/montreal_forced_aligner/ivector/trainer.py @@ -99,8 +99,7 @@ def export_model(self, output_model_path: Path) -> None: ivector_extractor.add_model(self.working_directory) if directory: os.makedirs(directory, exist_ok=True) - basename, _ = os.path.splitext(output_model_path) - ivector_extractor.dump(basename) + ivector_extractor.dump(output_model_path) class DubmTrainer(IvectorModelTrainingMixin): @@ -363,7 +362,7 @@ def finalize_training(self) -> None: """Finalize DUBM training""" final_dubm_path = self.working_directory.joinpath("final.dubm") shutil.copy( - self.working_directory.joinpath(f"{self.num_iterations+1}.dubm"), + self.working_directory.joinpath(f"{self.num_iterations + 1}.dubm"), final_dubm_path, ) # Update VAD with dubm likelihoods @@ -793,7 +792,7 @@ def train(self) -> None: self.set_current_workflow(trainer.identifier) trainer.train() previous = trainer - logger.info(f"Completed training in {time.time()-begin} seconds!") + logger.info(f"Completed training in {time.time() - begin} seconds!") def export_model(self, output_model_path: Path) -> None: """