Skip to content

Commit

Permalink
Add to trainer and callback API (#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
epwalsh authored Aug 31, 2024
2 parents 6569131 + 4d263f9 commit 1d237dc
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 39 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Added

- Added `Trainer.save_checkpoint()` and `Trainer.save_checkpoint_async()` methods.
- Added `Callback.post_checkpoint_saved()` and `Callback.post_checkpoint_loaded()` methods.
- Added `ConfigSaverCallback`.

### Changed

- The `work_dir` argument to `TrainerConfig` now defaults to `save_folder` is `save_folder` is a local path, otherwise a temporary directory with the same name as the basename of the `save_folder`.
Expand Down
14 changes: 7 additions & 7 deletions src/examples/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
torchrun --nproc-per-node=4 src/examples/train.py run_name [OVERRIDES...]
"""

import json
import sys
from dataclasses import dataclass
from typing import List, cast

from olmo_core.config import Config, DType
from olmo_core.data import MemMapDatasetConfig, TokenizerConfig
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.utils import get_rank, init_hybrid_shard_mesh
from olmo_core.distributed.utils import init_hybrid_shard_mesh
from olmo_core.nn.transformer import TransformerConfig
from olmo_core.optim import AdamWConfig, CosWithWarmup
from olmo_core.train import (
Expand All @@ -24,6 +23,7 @@
)
from olmo_core.train.callbacks import (
CheckpointerCallback,
ConfigSaverCallback,
GPUMemoryMonitorCallback,
GradClipperCallback,
SchedulerCallback,
Expand Down Expand Up @@ -99,6 +99,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig:
enabled=False, # change to true to enable
),
)
.with_callback("config_saver", ConfigSaverCallback())
)

return ExperimentConfig(
Expand All @@ -122,11 +123,10 @@ def main(run_name: str, overrides: List[str]):
dataset = config.dataset.build()
trainer = config.trainer.build(model, optim, dataset)

# Save config to W&B and file.
if get_rank() == 0:
config_dict = config.as_config_dict()
cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
trainer.write_file("config.json", json.dumps(config_dict, indent=2))
# Save config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()
cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

# Train.
trainer.fit()
Expand Down
15 changes: 15 additions & 0 deletions src/olmo_core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ def normalize_path(path: PathOrStr) -> str:
return str(path).rstrip("/").replace("file://", "")


def join_path(path1: PathOrStr, path2: PathOrStr) -> PathOrStr:
"""
Join two paths.
:param path1: The first path.
:param path2: The second path.
:returns: The joined result.
"""
if is_url(path1):
return f"{normalize_path(path1)}/{normalize_path(path2)}"
else:
return Path(path1) / path2


def resource_path(folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None) -> Path:
"""
Returns an actual path for local or remote file, potentially downloading it if a copy doesn't
Expand Down
2 changes: 2 additions & 0 deletions src/olmo_core/train/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .callback import Callback
from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy
from .config_saver import ConfigSaverCallback
from .console_logger import ConsoleLoggerCallback
from .garbage_collector import GarbageCollectorCallback
from .gpu_memory_monitor import GPUMemoryMonitorCallback
Expand All @@ -12,6 +13,7 @@
"Callback",
"CheckpointerCallback",
"CheckpointRemovalStrategy",
"ConfigSaverCallback",
"ConsoleLoggerCallback",
"GarbageCollectorCallback",
"GPUMemoryMonitorCallback",
Expand Down
18 changes: 18 additions & 0 deletions src/olmo_core/train/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Dict

from olmo_core.aliases import PathOrStr

if TYPE_CHECKING:
from ..trainer import Trainer

Expand Down Expand Up @@ -44,6 +46,14 @@ def step(self) -> int:
# def load_state_dict(self, state_dict: Dict[str, Any]):
# del state_dict

def post_checkpoint_loaded(self, path: PathOrStr):
"""
Called when a checkpoint is successfully loaded.
:param path: The path/URL to the checkpoint.
"""
del path

def pre_train(self):
"""
Runs before the training loop starts.
Expand Down Expand Up @@ -86,6 +96,14 @@ def post_step(self):
"""
pass

def post_checkpoint_saved(self, path: PathOrStr):
"""
Called when a checkpoint is successfully saved.
:param path: The path/URL to the checkpoint.
"""
del path

def log_metrics(self, step: int, metrics: Dict[str, float]):
"""
Called when metrics have been gathered for a given step (possibly a previous step).
Expand Down
22 changes: 3 additions & 19 deletions src/olmo_core/train/callbacks/checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,34 +118,18 @@ def _await_last_checkpoint(self, blocking: bool = True) -> Optional[Future]:
if blocking or fut.done():
fut.result()
self._future = None
log.info(f"Checkpoint for step {self._latest_checkpoint:,d} saved successfully")
return fut
return None

def _save_checkpoint(self, save_async: Optional[bool] = None) -> str:
save_async = save_async if save_async is not None else self.save_async
self._await_last_checkpoint()
self._latest_checkpoint = self.step
dirname = self.checkpointer.checkpoint_dirname(self.step)
path = f"{self.save_folder}/{dirname}"
if save_async:
log.info(f"Saving checkpoint for step {self.step} to '{path}' asynchronously...")
self._future = self.checkpointer.save_async(
path,
self.trainer.model,
self.trainer.optim,
self.trainer.state_dict(),
)
path, self._future = self.trainer.save_checkpoint_async()
else:
log.info(f"Saving checkpoint for step {self.step} to '{path}'...")
self.checkpointer.save(
path,
self.trainer.model,
self.trainer.optim,
self.trainer.state_dict(),
)
log.info("Checkpoint saved")
return path
path = self.trainer.save_checkpoint()
return str(path)

def _remove_checkpoint(self, path: str):
if get_fs_local_rank() == 0:
Expand Down
32 changes: 32 additions & 0 deletions src/olmo_core/train/callbacks/config_saver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import json
import logging
from dataclasses import dataclass
from typing import Any, Dict, Optional

from olmo_core.aliases import PathOrStr
from olmo_core.distributed.utils import get_rank

from .callback import Callback

log = logging.getLogger(__name__)


@dataclass
class ConfigSaverCallback(Callback):
"""
A callback that writes an arbitrary JSON-serializable config dictionary to every checkpoint
directory written during training.
"""

config: Optional[Dict[str, Any]] = None
fname: str = "config.json"

def post_checkpoint_saved(self, path: PathOrStr):
if get_rank() != 0:
return

if self.config is None:
log.warning(f"Config not set on {self.__class__.__name__}, doing nothing")
return

self.trainer.write_file(self.fname, json.dumps(self.config), dir=path)
55 changes: 50 additions & 5 deletions src/olmo_core/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import math
import signal
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union
Expand All @@ -29,7 +29,7 @@
scatter_object,
)
from ..exceptions import OLMoConfigurationError
from ..io import is_url, normalize_path
from ..io import is_url, join_path, normalize_path
from ..nn.functional.cross_entropy_loss import (
cross_entropy_loss,
fused_cross_entropy_loss,
Expand Down Expand Up @@ -626,6 +626,9 @@ def load_checkpoint(
assert trainer_state is not None
self.load_state_dict(trainer_state)

for callback in self.callbacks.values():
callback.post_checkpoint_loaded(dir)

self._checkpoint_loaded = True
log.info("Checkpoint successfully loaded")

Expand Down Expand Up @@ -654,6 +657,45 @@ def maybe_load_checkpoint(
log.warning(f"No checkpoint found in '{dir}', will train from scratch...")
return should_load

def save_checkpoint(self) -> PathOrStr:
"""
Save a checkpoint for the current step to the :data:`save_folder`.
:returns: The path/URL to the checkpoint.
"""
dirname = self.checkpointer.checkpoint_dirname(self.global_step)
path = join_path(self.save_folder, dirname)
log.info(f"Saving checkpoint for step {self.global_step} to '{path}'...")
self.checkpointer.save(path, self.model, self.optim, self.state_dict())
for callback in self.callbacks.values():
callback.post_checkpoint_saved(path)
log.info("Checkpoint saved")
return path

def save_checkpoint_async(self) -> Tuple[PathOrStr, Future]:
"""
Save a checkpoint for the current step to the :data:`save_folder` asynchronously.
:returns: The path/URL to the checkpoint and a future which will complete when the
checkpoint is successfully saved.
"""
step = self.global_step
dirname = self.checkpointer.checkpoint_dirname(step)
path = join_path(self.save_folder, dirname)

log.info(f"Saving checkpoint for step {step} to '{path}' asynchronously...")
fut = self.checkpointer.save_async(path, self.model, self.optim, self.state_dict())

def callback(future: Future):
future.result() # ensure it finished successfully
for callback in self.callbacks.values():
callback.post_checkpoint_saved(path)
log.info(f"Checkpoint for step {step} saved successfully")

fut.add_done_callback(callback)

return path, fut

def record_metric(
self, name: str, value: Union[float, torch.Tensor], reduce_type: Optional[ReduceType] = None
):
Expand All @@ -678,16 +720,19 @@ def record_metric(
self._metrics[self.global_step][name] = value
self._metrics_reduce_type[name] = reduce_type

def write_file(self, name: str, contents: Union[str, bytes]) -> PathOrStr:
def write_file(
self, name: str, contents: Union[str, bytes], dir: Optional[PathOrStr] = None
) -> PathOrStr:
"""
Write a file to the save folder.
Write a file to the :data:`save_folder` or ``dir``, if provided.
:param fname: The name of the file to write.
:param contents: The contents of the file to write.
:param dir: The path/URL to a directory to write the file to. Defaults to :data:`save_folder`.
:returns: The path/URL of the file.
"""
return self.checkpointer.write_file(self.save_folder, name, contents)
return self.checkpointer.write_file(dir or self.save_folder, name, contents)

def _duration_due(self, duration: Duration) -> bool:
if duration.unit == DurationUnit.steps:
Expand Down
15 changes: 7 additions & 8 deletions src/scripts/train/OLMo-7B.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Train a 7B OLMo model. Run this script without any arguments to see usage info.
"""

import json
import logging
import sys
from dataclasses import dataclass
Expand All @@ -13,7 +12,7 @@
from olmo_core.config import Config, DType, StrEnum
from olmo_core.data import DataMix, MemMapDatasetConfig, TokenizerConfig
from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType
from olmo_core.distributed.utils import get_num_nodes, get_rank, init_hybrid_shard_mesh
from olmo_core.distributed.utils import get_num_nodes, init_hybrid_shard_mesh
from olmo_core.io import is_url
from olmo_core.launch.beaker import (
BeakerEnvSecret,
Expand All @@ -30,6 +29,7 @@
)
from olmo_core.train.callbacks import (
CheckpointerCallback,
ConfigSaverCallback,
GPUMemoryMonitorCallback,
GradClipperCallback,
SchedulerCallback,
Expand Down Expand Up @@ -174,6 +174,7 @@ def build_config(run_name: str, cluster: str, overrides: List[str]) -> Experimen
cancel_check_interval=10,
),
)
.with_callback("config_saver", ConfigSaverCallback())
)

return ExperimentConfig(
Expand Down Expand Up @@ -205,12 +206,10 @@ def train(config: ExperimentConfig):
dataset = config.dataset.build()
trainer = config.trainer.build(model, optim, dataset)

# Record the config to W&B and to the save folder.
if get_rank() == 0:
config_dict = config.as_config_dict()
wandb_callback = cast(WandBCallback, trainer.callbacks["wandb"])
wandb_callback.config = config_dict
trainer.write_file("config.json", json.dumps(config_dict, indent=2))
# Record the config to W&B and each checkpoint dir.
config_dict = config.as_config_dict()
cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict
cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

# Train.
trainer.fit()
Expand Down

0 comments on commit 1d237dc

Please sign in to comment.