Skip to content

Commit

Permalink
Revert "Revert "fix: make --continue_path work again (#131)""
Browse files Browse the repository at this point in the history
This reverts commit 695a699.
  • Loading branch information
eginhard committed Dec 8, 2023
1 parent 7cfc3e3 commit 39dce9e
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 6 deletions.
13 changes: 12 additions & 1 deletion tests/test_continue_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,19 @@ def test_continue_train():
continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))

command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
# Continue training from the best model
command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
run_cli(command_continue)

assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))

# Continue training from the last checkpoint
for best in glob.glob(os.path.join(continue_path, "best_model*")):
os.remove(best)
run_cli(command_continue)

# Continue training from a specific checkpoint
restore_path = os.path.join(continue_path, "checkpoint_5.pth")
command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
run_cli(command_continue)
shutil.rmtree(continue_path)
5 changes: 4 additions & 1 deletion trainer/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ def save_best_model(
save_func=None,
**kwargs,
):
if current_loss < best_loss:
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
):
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
logger.info(" > BEST MODEL : %s", checkpoint_path)
Expand Down
15 changes: 11 additions & 4 deletions trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __init__( # pylint: disable=dangerous-default-value
self.epochs_done = 0
self.restore_step = 0
self.restore_epoch = 0
self.best_loss = float("inf")
self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
self.train_loader = None
self.test_loader = None
self.eval_loader = None
Expand Down Expand Up @@ -1724,8 +1724,15 @@ def _restore_best_loss(self):
logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
ch = load_fsspec(self.args.restore_path, map_location="cpu")
if "model_loss" in ch:
self.best_loss = ch["model_loss"]
logger.info(" > Starting with loaded last best loss %f", self.best_loss)
if isinstance(ch["model_loss"], dict):
self.best_loss = ch["model_loss"]
# For backwards-compatibility:
elif isinstance(ch["model_loss"], float):
if self.config.run_eval:
self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
else:
self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
logger.info(" > Starting with loaded last best loss %s", self.best_loss)

def test(self, model=None, test_samples=None) -> None:
"""Run evaluation steps on the test data split. You can either provide the model and the test samples
Expand Down Expand Up @@ -1907,7 +1914,7 @@ def save_best_model(self) -> None:

# save the model and update the best_loss
self.best_loss = save_best_model(
eval_loss if eval_loss else train_loss,
{"train_loss": train_loss, "eval_loss": eval_loss},
self.best_loss,
self.config,
self.model,
Expand Down

0 comments on commit 39dce9e

Please sign in to comment.