Skip to content

Commit

Permalink
updates
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh committed Aug 29, 2024
1 parent 3ad72ad commit da61ad5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@
.with_callback(
"checkpointer",
CheckpointerCallback(
save_interval=10_000,
ephemeral_save_interval=250,
save_interval=1000,
ephemeral_save_interval=50,
save_async=True,
pre_train_checkpoint=LOAD_PATH is None,
),
Expand Down
28 changes: 28 additions & 0 deletions src/olmo_core/train/callbacks/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
get_fs_local_rank,
is_distributed,
)
from olmo_core.exceptions import OLMoConfigurationError
from olmo_core.io import clear_directory

from ..checkpoint import Checkpointer
Expand All @@ -29,9 +30,25 @@ class CheckpointerCallback(Callback):
"""

save_interval: int = 250
"""
The interval, in steps, with which to save permanent checkpoints.
"""

ephemeral_save_interval: Optional[int] = None
"""
The interval, in steps, with which to save temporary checkpoints. It's useful to set this to
a frequent interval for preemptible jobs.
"""

pre_train_checkpoint: Optional[bool] = None
"""
Save a pretrain checkpoint. Defaults to ``True`` unless the trainer resumes from a checkpoint.
"""

save_async: bool = False
"""
Save checkpoints asynchronously. Requires a backend that supports CPU.
"""

# Bookkeeping

Expand All @@ -42,6 +59,17 @@ class CheckpointerCallback(Callback):
_checkpoints: List[str] = field(default_factory=list)
_ephemeral_checkpoints: List[str] = field(default_factory=list)

def __post_init__(self):
    """
    Validate the configured save intervals.

    :raises OLMoConfigurationError: If ``save_interval`` is less than 1, or if
        ``ephemeral_save_interval`` is set and is either less than 1 or not
        strictly less than ``save_interval``.
    """
    if self.save_interval < 1:
        raise OLMoConfigurationError("'save_interval' must be at least 1")

    ephemeral = self.ephemeral_save_interval
    if ephemeral is None:
        # No ephemeral interval configured, so there is nothing further to check.
        return

    if ephemeral < 1:
        raise OLMoConfigurationError("'ephemeral_save_interval' must be at least 1")

    if ephemeral >= self.save_interval:
        raise OLMoConfigurationError(
            "'ephemeral_save_interval' must be less than 'save_interval'"
        )

@property
def checkpointer(self) -> Checkpointer:
    """
    The checkpointer exposed by this callback's trainer (``self.trainer.checkpointer``).
    """
    trainer = self.trainer
    return trainer.checkpointer
Expand Down

0 comments on commit da61ad5

Please sign in to comment.