Skip to content

Commit

Permalink
change to is_temporary, fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Feb 18, 2025
1 parent 6c07ebd commit 3eca780
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def __init__(
latest_checkpoint = discover_latest_checkpoint(self.base_path)
if latest_checkpoint is not None and delete_old_temp_checkpoints:
metadata = _load_metadata(latest_checkpoint)
if metadata.get("is_permanent", True) is False:
if metadata.get("is_temporary", False):
logger.info(
f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after"
" saving a new checkpoint."
Expand Down Expand Up @@ -221,7 +221,7 @@ def callback():
if last_checkpoint is not None:
self._rm_checkpoint(last_checkpoint)

self.save_checkpoint(info, destination, commit_callback=callback, is_permanent=save_permanent_ckpt)
self.save_checkpoint(info, destination, commit_callback=callback, is_temporary=not save_permanent_ckpt)

def _get_current_step_save_interval(self, step):
# binary search for the correct interval
Expand Down Expand Up @@ -256,7 +256,12 @@ def _do_rm_checkpoint(self, checkpoint):
logger.exception(f"Failed to delete checkpoint {checkpoint}", exc_info=True)

def save_checkpoint(
self, info, destination: str, commit_callback: Optional[Callable[[], None]] = None, *, is_permanent: bool
self,
info,
destination: str,
commit_callback: Optional[Callable[[], None]] = None,
*,
is_temporary: bool = False,
):
path = os.path.join(self.base_path, destination)
logger.info(f"Saving checkpoint at step {info.step} to {path}")
Expand All @@ -268,7 +273,7 @@ def save_checkpoint(
checkpoint_path=path,
manager=self._manager,
commit_callback=commit_callback,
is_permanent=is_permanent,
is_temporary=is_temporary,
)
self._last_save_step = info.step
self._last_save_time = self._dt_now_injection()
Expand Down Expand Up @@ -314,7 +319,7 @@ def __call__(self, step_info):
self.checkpointer.save_checkpoint(
step_info,
f"epoch-{current_epoch}",
is_permanent=True,
is_temporary=True,
)
self._last_saved_epoch = current_epoch

Expand All @@ -326,7 +331,7 @@ def save_checkpoint(
manager: Optional[GlobalAsyncCheckpointManager] = None,
*,
commit_callback: Optional[Callable[[], None]] = None,
is_permanent: bool = True,
is_temporary: bool = True,
):
"""
Save a checkpoint to a given path using TensorStore.
Expand All @@ -336,6 +341,14 @@ def save_checkpoint(
If training_state is None, no training state will be saved.
This method is jax.Array-aware and will save shards in a way that can be restored
Args:
tree: the PyTree to save
step: the step to save the checkpoint at
checkpoint_path: the path to save the checkpoint to
manager: the GlobalAsyncCheckpointManager to use for saving the checkpoint
commit_callback: a callback to call after the checkpoint has been saved
is_temporary: whether the checkpoint is temporary
"""
step = int(step)
checkpoint_path = str(checkpoint_path)
Expand All @@ -346,7 +359,7 @@ def save_checkpoint(
fs.makedirs(plain_path, exist_ok=True)

def my_callback():
_save_metadata(checkpoint_path, fs, step, is_permanent)
_save_metadata(checkpoint_path, fs, step, is_temporary)
logger.info(f"Saved checkpoint to {checkpoint_path} for step {step}")

if commit_callback is not None:
Expand All @@ -359,8 +372,8 @@ def my_callback():
return checkpoint_path


def _save_metadata(checkpoint_path, fs, step, is_permanent):
metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat(), "is_permanent": is_permanent}
def _save_metadata(checkpoint_path, fs, step, is_temporary):
metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat(), "is_temporary": is_temporary}
if jax.process_index() == 0:
with fs.open(os.path.join(checkpoint_path, "metadata.json"), "w") as json_out:
json.dump(metadata, json_out)
Expand Down

0 comments on commit 3eca780

Please sign in to comment.