Skip to content

Commit

Permalink
clean up old temp checkpoints after preemptions
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Feb 18, 2025
1 parent 748e198 commit 6c07ebd
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 9 deletions.
41 changes: 34 additions & 7 deletions src/levanter/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
*,
keep_params: PyTree[FilterSpec] = True,
dt_now_injection: Optional[Callable[[], datetime.datetime]] = None,
delete_old_temp_checkpoints: bool = True,
):
"""
Class for managing checkpoints. Saves checkpoints according to two policies: time and step.
Expand All @@ -83,6 +84,7 @@ def __init__(
step_policies: the step policies to use
keep_params: a PyTree of FilterSpecs that specifies which parameters to keep in the checkpoint
dt_now_injection: a function that returns the current time. useful for testing
delete_old_temp_checkpoints: if True, delete old checkpoints when saving a new one
"""
self.base_path = str(base_path)
self.save_interval = save_interval
Expand All @@ -91,7 +93,6 @@ def __init__(
self._dt_now_injection = dt_now_injection or datetime.datetime.now
self._last_save_time = self._dt_now_injection()
self._last_save_step = 0
self._last_temporary_checkpoint = None

# ensure that the step_policies are sorted. We could sort, but instead we'll just insist that they are sorted
# since it's probably a typo if they aren't
Expand All @@ -117,6 +118,18 @@ def __init__(
self._async_checkpoint_remover_thread.start()
self._checkpoint_being_removed = None

# discover latest checkpoint and see if it's temporary
self._last_temporary_checkpoint = None
latest_checkpoint = discover_latest_checkpoint(self.base_path)
if latest_checkpoint is not None and delete_old_temp_checkpoints:
metadata = _load_metadata(latest_checkpoint)
if metadata.get("is_permanent", True) is False:
logger.info(
f"Found prior temporary checkpoint {latest_checkpoint}. We will delete it after"
" saving a new checkpoint."
)
self._last_temporary_checkpoint = latest_checkpoint

def load_checkpoint(
self,
state: M,
Expand Down Expand Up @@ -185,6 +198,8 @@ def on_step(self, info, force: bool = False):
should_save, save_permanent_ckpt = broadcast_one_to_all(
jnp.array([my_should_save, my_save_permanent_ckpt], dtype=jnp.bool_)
)
# this comes out as np.bool_, so we need to convert it to a regular bool so json serialization works
save_permanent_ckpt = bool(save_permanent_ckpt)

# log the decision
if should_save:
Expand All @@ -206,7 +221,7 @@ def callback():
if last_checkpoint is not None:
self._rm_checkpoint(last_checkpoint)

self.save_checkpoint(info, destination, commit_callback=callback)
self.save_checkpoint(info, destination, commit_callback=callback, is_permanent=save_permanent_ckpt)

def _get_current_step_save_interval(self, step):
# binary search for the correct interval
Expand Down Expand Up @@ -240,7 +255,9 @@ def _do_rm_checkpoint(self, checkpoint):
except Exception: # pylint: disable=broad-except
logger.exception(f"Failed to delete checkpoint {checkpoint}", exc_info=True)

def save_checkpoint(self, info, destination: str, commit_callback: Optional[Callable[[], None]] = None):
def save_checkpoint(
self, info, destination: str, commit_callback: Optional[Callable[[], None]] = None, *, is_permanent: bool
):
path = os.path.join(self.base_path, destination)
logger.info(f"Saving checkpoint at step {info.step} to {path}")
state = info.state.saveable_state
Expand All @@ -251,6 +268,7 @@ def save_checkpoint(self, info, destination: str, commit_callback: Optional[Call
checkpoint_path=path,
manager=self._manager,
commit_callback=commit_callback,
is_permanent=is_permanent,
)
self._last_save_step = info.step
self._last_save_time = self._dt_now_injection()
Expand Down Expand Up @@ -296,6 +314,7 @@ def __call__(self, step_info):
self.checkpointer.save_checkpoint(
step_info,
f"epoch-{current_epoch}",
is_permanent=True,
)
self._last_saved_epoch = current_epoch

Expand All @@ -307,6 +326,7 @@ def save_checkpoint(
manager: Optional[GlobalAsyncCheckpointManager] = None,
*,
commit_callback: Optional[Callable[[], None]] = None,
is_permanent: bool = True,
):
"""
Save a checkpoint to a given path using TensorStore.
Expand All @@ -326,7 +346,7 @@ def save_checkpoint(
fs.makedirs(plain_path, exist_ok=True)

def my_callback():
save_metadata(checkpoint_path, fs, step)
_save_metadata(checkpoint_path, fs, step, is_permanent)
logger.info(f"Saved checkpoint to {checkpoint_path} for step {step}")

if commit_callback is not None:
Expand All @@ -339,8 +359,8 @@ def my_callback():
return checkpoint_path


def save_metadata(checkpoint_path, fs, step):
metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat()}
def _save_metadata(checkpoint_path, fs, step, is_permanent):
metadata = {"step": step, "timestamp": datetime.datetime.now().isoformat(), "is_permanent": is_permanent}
if jax.process_index() == 0:
with fs.open(os.path.join(checkpoint_path, "metadata.json"), "w") as json_out:
json.dump(metadata, json_out)
Expand Down Expand Up @@ -506,7 +526,7 @@ def load_or_init(*args, **kwargs):
return load_or_init


def load_metadata(checkpoint_path, fs=None):
def _load_metadata(checkpoint_path, fs=None):
if fs is None:
fs, _, _ = fsspec.get_fs_token_paths(str(checkpoint_path))
with fs.open(os.path.join(checkpoint_path, "metadata.json")) as metadata_in:
Expand Down Expand Up @@ -568,6 +588,12 @@ class CheckpointerConfig:
) # list of dicts with two keys: every and until

append_run_id_to_base_path: bool = True
delete_old_temp_checkpoints: bool = True
"""
If True, delete old checkpoints from prior attempts at this run. If False, keep them.
This is useful if the run is being preempted and restarted, and you want to keep the old checkpoints.
"""

def expanded_path(self, run_id) -> str:
if self.append_run_id_to_base_path:
Expand All @@ -580,6 +606,7 @@ def create(self, run_id) -> Checkpointer:
base_path=self.expanded_path(run_id),
save_interval=self.save_interval,
step_policies=keeps,
delete_old_temp_checkpoints=self.delete_old_temp_checkpoints,
)

def __post_init__(self):
Expand Down
70 changes: 68 additions & 2 deletions tests/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from levanter.checkpoint import (
Checkpointer,
CheckpointInterval,
_load_metadata,
discover_latest_checkpoint,
load_checkpoint,
load_checkpoint_or_initialize,
load_metadata,
save_checkpoint,
)
from levanter.trainer_state import TrainerState
Expand All @@ -49,7 +49,7 @@ def _dummy_step_info(step):

def _get_checkpoint_steps(checkpoint_dir):
paths = list(pathlib.Path(checkpoint_dir).iterdir())
return sorted([load_metadata(f)["step"] for f in paths])
return sorted([_load_metadata(f)["step"] for f in paths])


def test_checkpointer_changing_policy():
Expand Down Expand Up @@ -247,6 +247,72 @@ def test_checkpoint_discovery():
assert discover_latest_checkpoint("file:///tmp/does-not-exist") is None


def test_checkpointer_deletes_previous_checkpoints():
fake_now = datetime.datetime(2021, 1, 1, 0, 0, 0)

tick = 10

def advance_time(delta_seconds):
nonlocal fake_now
fake_now += timedelta(seconds=delta_seconds)

with tempfile.TemporaryDirectory() as tmpdir:
checkpointer = Checkpointer(
tmpdir,
timedelta(seconds=tick),
[
CheckpointInterval(every=5, until=20),
CheckpointInterval(every=10, until=None),
],
dt_now_injection=lambda: fake_now,
)

checkpointer.on_step(_dummy_step_info(0))
advance_time(tick)
for i in range(1, 6):
checkpointer.on_step(_dummy_step_info(i))
checkpointer.wait_until_finished()
assert _get_checkpoint_steps(tmpdir) == [5]
advance_time(tick)
checkpointer.on_step(_dummy_step_info(6))
checkpointer.wait_until_finished()
assert _get_checkpoint_steps(tmpdir) == [5, 6]

# now make a new one and ensure it deletes the old one
checkpointer = Checkpointer(
tmpdir,
timedelta(seconds=tick),
[
CheckpointInterval(every=5, until=20),
CheckpointInterval(every=10, until=None),
],
dt_now_injection=lambda: fake_now,
)

checkpointer.on_step(_dummy_step_info(7))
advance_time(tick)
checkpointer.on_step(_dummy_step_info(8))
checkpointer.wait_until_finished()
assert _get_checkpoint_steps(tmpdir) == [5, 8]

# now make sure if we don't enable deleting old checkpoints, it doesn't delete them
checkpointer = Checkpointer(
tmpdir,
timedelta(seconds=tick),
[
CheckpointInterval(every=20, until=None),
],
dt_now_injection=lambda: fake_now,
delete_old_temp_checkpoints=False,
)

checkpointer.on_step(_dummy_step_info(9))
advance_time(tick)
checkpointer.on_step(_dummy_step_info(10))
checkpointer.wait_until_finished()
assert _get_checkpoint_steps(tmpdir) == [5, 8, 10]


def test_load_from_checkpoint_or_initialize():
In = Axis("in", 2)
Out = Axis("out", 1)
Expand Down

0 comments on commit 6c07ebd

Please sign in to comment.