From 39dce9e0bc41a208521c230c1a16da394e003784 Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Fri, 8 Dec 2023 08:58:32 +0100
Subject: [PATCH] Revert "Revert "fix: make --continue_path work again (#131)""

This reverts commit 695a699cba94fbca7aa19f0f4195bf6826e314f9.
---
 tests/test_continue_train.py | 13 ++++++++++++-
 trainer/io.py                |  5 ++++-
 trainer/trainer.py           | 15 +++++++++++----
 3 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/tests/test_continue_train.py b/tests/test_continue_train.py
index 6bd158f..cc6632b 100644
--- a/tests/test_continue_train.py
+++ b/tests/test_continue_train.py
@@ -14,8 +14,19 @@ def test_continue_train():
     continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
     number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
 
-    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
+    # Continue training from the best model
+    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
     run_cli(command_continue)
     assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
+
+    # Continue training from the last checkpoint
+    for best in glob.glob(os.path.join(continue_path, "best_model*")):
+        os.remove(best)
+    run_cli(command_continue)
+
+    # Continue training from a specific checkpoint
+    restore_path = os.path.join(continue_path, "checkpoint_5.pth")
+    command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
+    run_cli(command_continue)
 
     shutil.rmtree(continue_path)
diff --git a/trainer/io.py b/trainer/io.py
index 6e08aea..eb34082 100644
--- a/trainer/io.py
+++ b/trainer/io.py
@@ -180,7 +180,10 @@ def save_best_model(
     save_func=None,
     **kwargs,
 ):
-    if current_loss < best_loss:
+    use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
+    if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
+        not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
+    ):
         best_model_name = f"best_model_{current_step}.pth"
         checkpoint_path = os.path.join(out_path, best_model_name)
         logger.info(" > BEST MODEL : %s", checkpoint_path)
diff --git a/trainer/trainer.py b/trainer/trainer.py
index cc74024..a62b2b1 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -451,7 +451,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.epochs_done = 0
         self.restore_step = 0
         self.restore_epoch = 0
-        self.best_loss = float("inf")
+        self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
         self.train_loader = None
         self.test_loader = None
         self.eval_loader = None
@@ -1724,8 +1724,15 @@ def _restore_best_loss(self):
             logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
             ch = load_fsspec(self.args.restore_path, map_location="cpu")
             if "model_loss" in ch:
-                self.best_loss = ch["model_loss"]
-            logger.info(" > Starting with loaded last best loss %f", self.best_loss)
+                if isinstance(ch["model_loss"], dict):
+                    self.best_loss = ch["model_loss"]
+                # For backwards-compatibility:
+                elif isinstance(ch["model_loss"], float):
+                    if self.config.run_eval:
+                        self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
+                    else:
+                        self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
+            logger.info(" > Starting with loaded last best loss %s", self.best_loss)
 
     def test(self, model=None, test_samples=None) -> None:
         """Run evaluation steps on the test data split. You can either provide the model and the test samples
@@ -1907,7 +1914,7 @@ def save_best_model(self) -> None:
 
         # save the model and update the best_loss
        self.best_loss = save_best_model(
-            eval_loss if eval_loss else train_loss,
+            {"train_loss": train_loss, "eval_loss": eval_loss},
             self.best_loss,
             self.config,
             self.model,
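
For context, a minimal sketch (not part of the patch) of the comparison the patched save_best_model() in trainer/io.py performs on the new dict-shaped losses: eval_loss is preferred whenever both the current and the stored best loss provide one, and train_loss is the fallback when evaluation is disabled. The standalone is_new_best() helper below is hypothetical and only mirrors the patched condition for illustration.

# Hypothetical helper mirroring the condition added to save_best_model():
# use eval_loss when both the current and the best loss carry one,
# otherwise fall back to train_loss (the run_eval=False case).
def is_new_best(current_loss: dict, best_loss: dict) -> bool:
    use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
    if use_eval_loss:
        return current_loss["eval_loss"] < best_loss["eval_loss"]
    return current_loss["train_loss"] < best_loss["train_loss"]

# With run_eval disabled, eval_loss stays None and only train_loss is compared,
# matching the initial value now set in Trainer.__init__.
best_loss = {"train_loss": float("inf"), "eval_loss": None}
assert is_new_best({"train_loss": 0.42, "eval_loss": None}, best_loss)

# With run_eval enabled, eval_loss alone decides, even if train_loss improved.
best_loss = {"train_loss": float("inf"), "eval_loss": 0.30}
assert not is_new_best({"train_loss": 0.10, "eval_loss": 0.35}, best_loss)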