diff --git a/everyvoice/base_cli/callback.py b/everyvoice/base_cli/callback.py
new file mode 100644
index 00000000..42fb70ef
--- /dev/null
+++ b/everyvoice/base_cli/callback.py
@@ -0,0 +1,33 @@
+from typing import Any, Dict
+
+import pytorch_lightning as pl
+from pytorch_lightning.callbacks import Callback
+from typing_extensions import override
+
+
+class ResetValidationDataloaderCallback(Callback):
+    """Reset the validation loop's batch progress when a checkpoint is saved.
+
+    Intended to allow resuming and validating a full validation set, and not
+    just the first example in the validation set.
+    """
+
+    @override
+    def on_save_checkpoint(
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        checkpoint: Dict[str, Any],
+    ) -> None:
+        """Reset the live validation batch progress held by ``trainer``.
+
+        Args:
+            trainer: Trainer whose fit loop's validation progress is reset.
+            pl_module: Unused here.
+            checkpoint: Unused here; this hook does not modify it.
+        """
+        # NOTE(review): only the trainer's in-memory progress is reset; the
+        # ``checkpoint`` dict is left untouched -- confirm that is sufficient
+        # for the checkpoint currently being written.
+        batch_progress = trainer.fit_loop.epoch_loop.val_loop.batch_progress
+        batch_progress.reset()
diff --git a/everyvoice/base_cli/helpers.py b/everyvoice/base_cli/helpers.py
index be675def..c6efa365 100644
--- a/everyvoice/base_cli/helpers.py
+++ b/everyvoice/base_cli/helpers.py
@@ -168,6 +168,8 @@ def train_base_command(
     gradient_clip_val: float | None,
     model_kwargs={},
 ):
+    from everyvoice.base_cli.callback import ResetValidationDataloaderCallback
+
     config = load_config_base_command(model_config, config_args, config_file)
 
     save_configuration_to_log_dir(config)
@@ -197,6 +199,7 @@ def train_base_command(
             **{"sub_dir": config.training.logger.sub_dir},
         }
     )
+
     lr_monitor = LearningRateMonitor(logging_interval="step")
     logger.info("Starting training.")
     # This callback will always save the last checkpoint
@@ -226,7 +229,12 @@ def train_base_command(
         max_steps=config.training.max_steps,
         check_val_every_n_epoch=config.training.check_val_every_n_epoch,
         val_check_interval=config.training.val_check_interval,
-        callbacks=[monitored_ckpt_callback, last_ckpt_callback, lr_monitor],
+        callbacks=[
+            monitored_ckpt_callback,
+            last_ckpt_callback,
+            lr_monitor,
+            ResetValidationDataloaderCallback(),
+        ],
         strategy=strategy,
         num_nodes=nodes,
         detect_anomaly=False,  # used for debugging, but triples training time
diff --git a/everyvoice/model/feature_prediction/FastSpeech2_lightning b/everyvoice/model/feature_prediction/FastSpeech2_lightning
index ec675cfa..b6e373d0 160000
--- a/everyvoice/model/feature_prediction/FastSpeech2_lightning
+++ b/everyvoice/model/feature_prediction/FastSpeech2_lightning
@@ -1 +1 @@
-Subproject commit ec675cfa2d374725c5f2c6397ecdeb29c923beeb
+Subproject commit b6e373d0cf80ac687f78136b1e680f2bdd874e15