reduce number of S3 requests needed
epwalsh committed Aug 28, 2024
1 parent 86c75d5 commit 1744420
Showing 3 changed files with 30 additions and 20 deletions.
16 changes: 8 additions & 8 deletions src/olmo_core/train/trainer.py
@@ -512,18 +512,18 @@ def load_checkpoint(
         """
         Load a checkpoint.
 
-        :param dir: The path/URL to the checkpoint.
+        :param dir: The path/URL to a checkpoint or a folder of checkpoints.
         :param load_optimizer_state: Load optimizer state.
         :param load_trainer_state: Load trainer state.
         """
-        if not self.checkpointer.dir_is_checkpoint(dir):
+        dir = normalize_path(dir)
+
+        # NOTE: to avoid making a ton of client requests (S3 or otherwise) we only make those
+        # requests from rank 0 then scatter the result to the other ranks.
+        if get_rank() == 0 and not self.checkpointer.dir_is_checkpoint(dir):
             # Try to find the latest checkpoint in the directory.
-            latest_checkpoint: Optional[str] = None
-            if get_rank() == 0:
-                latest_checkpoint = self.checkpointer.latest_checkpoint(dir)
-            latest_checkpoint = scatter_object(latest_checkpoint)
-            assert latest_checkpoint is not None
-            dir = latest_checkpoint
+            dir = self.checkpointer.latest_checkpoint(dir)
+        dir = scatter_object(dir)
 
         log.info(f"Loading checkpoint from '{dir}'...")
         trainer_state = self.checkpointer.load(
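The change above replaces a per-rank directory probe with a single probe on rank 0 whose result is broadcast, so a job with N ranks now issues one S3 request where it previously issued N. As a rough sketch of the idiom (not the actual olmo_core implementation), a scatter_object-style helper can be built directly on torch.distributed:

# Sketch of a rank-0 broadcast helper, assuming an initialized
# torch.distributed process group. The real
# olmo_core.distributed.utils.scatter_object may be implemented differently.
from typing import Any

import torch.distributed as dist


def scatter_object(obj: Any, src: int = 0) -> Any:
    """Return rank `src`'s copy of `obj` on every rank."""
    if not dist.is_available() or not dist.is_initialized():
        return obj  # single-process fallback
    out = [obj]
    # broadcast_object_list pickles out[0] on `src` and overwrites it elsewhere.
    dist.broadcast_object_list(out, src=src)
    return out[0]

With a helper like this, every rank calls dir = scatter_object(dir) unconditionally: the collective keeps the ranks in agreement on the resolved path, while only rank 0 ever touches S3 to find the latest checkpoint.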
28 changes: 18 additions & 10 deletions src/scripts/train/OLMo-7B.py
@@ -13,7 +13,12 @@
 from olmo_core.config import Config, DType, StrEnum
 from olmo_core.data import DataMix, MemMapDatasetConfig, TokenizerConfig
 from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
-from olmo_core.distributed.utils import get_num_nodes, get_rank, init_hybrid_shard_mesh
+from olmo_core.distributed.utils import (
+    get_num_nodes,
+    get_rank,
+    init_hybrid_shard_mesh,
+    scatter_object,
+)
 from olmo_core.io import is_url
 from olmo_core.launch.beaker import (
     BeakerEnvSecret,
@@ -221,15 +226,18 @@ def train(config: ExperimentConfig):
     dataset = config.dataset.build()
     trainer = config.trainer.build(model, optim, dataset)
 
-    # Maybe load a checkpoint.
-    if (load_path := config.load_path) is not None and (
-        config.load_strategy == LoadStrategy.always
-        or (
-            config.load_strategy == LoadStrategy.if_available
-            and trainer.checkpointer.contains_checkpoint(load_path)
-        )
-    ):
-        trainer.load_checkpoint(load_path)
+    if (load_path := config.load_path) is not None:
+        # Maybe load a checkpoint.
+        should_load: bool = True
+        if config.load_strategy == LoadStrategy.never:
+            should_load = False
+        elif config.load_strategy == LoadStrategy.if_available:
+            if get_rank() == 0:
+                should_load = trainer.checkpointer.contains_checkpoint(load_path)
+            should_load = scatter_object(should_load)
+
+        if should_load:
+            trainer.load_checkpoint(load_path)
     elif get_rank() == 0:
         # Save config to file.
         trainer.write_file("config.json", json.dumps(config_dict, indent=2))
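Here the expensive call is contains_checkpoint, which under LoadStrategy.if_available previously ran on every rank; the decision is now computed once on rank 0 and broadcast. The pattern generalizes to a small wrapper; rank0_only below is a hypothetical helper, not part of olmo_core:

# Hypothetical wrapper around the idiom used in this commit. get_rank and
# scatter_object are the real olmo_core.distributed.utils functions.
from typing import Callable, TypeVar

from olmo_core.distributed.utils import get_rank, scatter_object

T = TypeVar("T")


def rank0_only(fn: Callable[[], T], default: T) -> T:
    """Evaluate fn (e.g. an S3-backed existence check) on rank 0 only,
    then broadcast the result so every rank agrees on it."""
    result = fn() if get_rank() == 0 else default
    return scatter_object(result)

The if_available branch above could then read: should_load = rank0_only(lambda: trainer.checkpointer.contains_checkpoint(load_path), False).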
6 changes: 4 additions & 2 deletions src/test/train/checkpoint_test.py
@@ -11,8 +11,8 @@
 from ..distributed.utils import run_distributed_test
 
 
-def run_checkpointer(dir, model_factory):
-    dir = normalize_path(dir)
+def run_checkpointer(base_dir, model_factory):
+    dir = f"{normalize_path(base_dir)}/{Checkpointer.checkpoint_dirname(10)}"
 
     if not is_url(dir):
         os.environ["OLMO_SHARED_FS"] = "1"
@@ -29,6 +29,8 @@ def run_checkpointer(dir, model_factory):
     assert file_exists((f"{dir}/train/rank1.pt"))
     assert not dir_is_empty((f"{dir}/model_and_optim"))
     assert checkpointer.dir_is_checkpoint(dir)
+    assert list(checkpointer.find_checkpoints(base_dir)) == [(10, dir)]
+    assert checkpointer.latest_checkpoint(base_dir) == dir
 
     # Load checkpoint.
     train_state = checkpointer.load(dir, model, optim)
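The new assertions pin down the discovery API: find_checkpoints(base_dir) yields (step, path) pairs and latest_checkpoint(base_dir) returns the path for the highest step. A minimal sketch of that kind of logic, assuming checkpoint directories are named like step10 (the naming comes from Checkpointer.checkpoint_dirname; the real implementation may differ):

# Illustrative only; assumes "step<N>" directory names, consistent with the
# test expecting Checkpointer.checkpoint_dirname(10) to map step 10 to one dir.
import re
from typing import Iterator, List, Tuple

_STEP_RE = re.compile(r"^step(\d+)$")


def find_checkpoints(entries: List[str], base: str) -> Iterator[Tuple[int, str]]:
    """Yield (step, path) for each entry that looks like a checkpoint dir."""
    for name in entries:
        if (match := _STEP_RE.match(name)) is not None:
            yield int(match.group(1)), f"{base}/{name}"


def latest_checkpoint(entries: List[str], base: str) -> str:
    # Tuples compare element-wise, so max() picks the highest step number.
    return max(find_checkpoints(entries, base))[1]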
