From 3eca780cd2904a81077266680528c9271ea0e56b Mon Sep 17 00:00:00 2001 From: David Hall Date: Tue, 18 Feb 2025 10:59:29 -0800 Subject: [PATCH] change to is_temporary, fix tests --- src/levanter/checkpoint.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/levanter/checkpoint.py b/src/levanter/checkpoint.py index 658144fd8..1bacde09e 100644 --- a/src/levanter/checkpoint.py +++ b/src/levanter/checkpoint.py @@ -123,7 +123,7 @@ def __init__( latest_checkpoint = discover_latest_checkpoint(self.base_path) if latest_checkpoint is not None and delete_old_temp_checkpoints: metadata = _load_metadata(latest_checkpoint) - if metadata.get("is_permanent", True) is False: + if metadata.get("is_temporary", metadata.get("is_permanent", True) is False): logger.info( f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after" " saving a new checkpoint." @@ -221,7 +221,7 @@ def callback(): if last_checkpoint is not None: self._rm_checkpoint(last_checkpoint) - self.save_checkpoint(info, destination, commit_callback=callback, is_permanent=save_permanent_ckpt) + self.save_checkpoint(info, destination, commit_callback=callback, is_temporary=not save_permanent_ckpt) def _get_current_step_save_interval(self, step): # binary search for the correct interval @@ -256,7 +256,12 @@ def _do_rm_checkpoint(self, checkpoint): logger.exception(f"Failed to delete checkpoint {checkpoint}", exc_info=True) def save_checkpoint( - self, info, destination: str, commit_callback: Optional[Callable[[], None]] = None, *, is_permanent: bool + self, + info, + destination: str, + commit_callback: Optional[Callable[[], None]] = None, + *, + is_temporary: bool = False, ): path = os.path.join(self.base_path, destination) logger.info(f"Saving checkpoint at step {info.step} to {path}") @@ -268,7 +273,7 @@ def save_checkpoint( checkpoint_path=path, manager=self._manager, commit_callback=commit_callback, - is_permanent=is_permanent, + is_temporary=is_temporary, ) self._last_save_step
= info.step self._last_save_time = self._dt_now_injection() @@ -314,7 +319,7 @@ def __call__(self, step_info): self.checkpointer.save_checkpoint( step_info, f"epoch-{current_epoch}", - is_permanent=True, + is_temporary=False, ) self._last_saved_epoch = current_epoch @@ -326,7 +331,7 @@ def save_checkpoint( manager: Optional[GlobalAsyncCheckpointManager] = None, *, commit_callback: Optional[Callable[[], None]] = None, - is_permanent: bool = True, + is_temporary: bool = False, ): """ Save a checkpoint to a given path using TensorStore. @@ -336,6 +341,14 @@ def save_checkpoint( If training_state is None, no training state will be saved. This method is jax.Array-aware and will save shards in a way that can be restored + + Args: + tree: the PyTree to save + step: the step to save the checkpoint at + checkpoint_path: the path to save the checkpoint to + manager: the GlobalAsyncCheckpointManager to use for saving the checkpoint + commit_callback: a callback to call after the checkpoint has been saved + is_temporary: whether the checkpoint is temporary; the flag is recorded in metadata.json (defaults to False) """ step = int(step) checkpoint_path = str(checkpoint_path) @@ -346,7 +359,7 @@ def save_checkpoint( fs.makedirs(plain_path, exist_ok=True) def my_callback(): - _save_metadata(checkpoint_path, fs, step, is_permanent) + _save_metadata(checkpoint_path, fs, step, is_temporary) logger.info(f"Saved checkpoint to {checkpoint_path} for step {step}") if commit_callback is not None: @@ -359,8 +372,8 @@ def my_callback(): return checkpoint_path -def _save_metadata(checkpoint_path, fs, step, is_permanent): - metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat(), "is_permanent": is_permanent} +def _save_metadata(checkpoint_path, fs, step, is_temporary): + metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat(), "is_temporary": is_temporary} if jax.process_index() == 0: with fs.open(os.path.join(checkpoint_path, "metadata.json"), "w") as json_out: json.dump(metadata, json_out)