From eb97945e4c09a09775521e54dcc8328596dc7e6e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 31 Aug 2024 11:19:17 -0700 Subject: [PATCH 1/3] Add `Trainer.save_checkpoint(_async)` methods --- CHANGELOG.md | 4 ++ src/olmo_core/io.py | 15 +++++++ src/olmo_core/train/callbacks/checkpointer.py | 22 ++-------- src/olmo_core/train/trainer.py | 44 +++++++++++++++++-- 4 files changed, 63 insertions(+), 22 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ab30f0e..39cbb355 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added `Trainer.save_checkpoint()` and `Trainer.save_checkpoint_async()` methods. + ### Changed - The `work_dir` argument to `TrainerConfig` now defaults to `save_folder` is `save_folder` is a local path, otherwise a temporary directory with the same name as the basename of the `save_folder`. diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index a4736740..d5dcabc4 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -36,6 +36,21 @@ def normalize_path(path: PathOrStr) -> str: return str(path).rstrip("/").replace("file://", "") +def join_path(path1: PathOrStr, path2: PathOrStr) -> PathOrStr: + """ + Join two paths. + + :param path1: The first path. + :param path2: The second path. + + :returns: The joined result. + """ + if is_url(path1): + return f"{normalize_path(path1)}/{normalize_path(path2)}" + else: + return Path(path1) / path2 + + def resource_path(folder: PathOrStr, fname: str, local_cache: Optional[PathOrStr] = None) -> Path: """ Returns an actual path for local or remote file, potentially downloading it if a copy doesn't diff --git a/src/olmo_core/train/callbacks/checkpointer.py b/src/olmo_core/train/callbacks/checkpointer.py index 089441ac..6a88f9d5 100644 --- a/src/olmo_core/train/callbacks/checkpointer.py +++ b/src/olmo_core/train/callbacks/checkpointer.py @@ -118,7 +118,6 @@ def _await_last_checkpoint(self, blocking: bool = True) -> Optional[Future]: if blocking or fut.done(): fut.result() self._future = None - log.info(f"Checkpoint for step {self._latest_checkpoint:,d} saved successfully") return fut return None @@ -126,26 +125,11 @@ def _save_checkpoint(self, save_async: Optional[bool] = None) -> str: save_async = save_async if save_async is not None else self.save_async self._await_last_checkpoint() self._latest_checkpoint = self.step - dirname = self.checkpointer.checkpoint_dirname(self.step) - path = f"{self.save_folder}/{dirname}" if save_async: - log.info(f"Saving checkpoint for step {self.step} to '{path}' asynchronously...") - self._future = self.checkpointer.save_async( - path, - self.trainer.model, - self.trainer.optim, - self.trainer.state_dict(), - ) + path, self._future = self.trainer.save_checkpoint_async() else: - log.info(f"Saving checkpoint for step {self.step} to '{path}'...") - self.checkpointer.save( - path, - self.trainer.model, - self.trainer.optim, - self.trainer.state_dict(), - ) - log.info("Checkpoint saved") - return path + path = self.trainer.save_checkpoint() + return str(path) def _remove_checkpoint(self, path: str): if get_fs_local_rank() == 0: diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 9a5e1ba1..9e360474 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -3,7 +3,7 @@ import math import signal from collections import OrderedDict -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, 
ThreadPoolExecutor from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Generator, List, Literal, Optional, Tuple, Union @@ -29,7 +29,7 @@ scatter_object, ) from ..exceptions import OLMoConfigurationError -from ..io import is_url, normalize_path +from ..io import is_url, join_path, normalize_path from ..nn.functional.cross_entropy_loss import ( cross_entropy_loss, fused_cross_entropy_loss, @@ -654,6 +654,44 @@ def maybe_load_checkpoint( log.warning(f"No checkpoint found in '{dir}', will train from scratch...") return should_load + def save_checkpoint(self) -> PathOrStr: + """ + Save a checkpoint for the current step to the :data:`save_folder`. + + :returns: The path/URL to the checkpoint. + """ + dirname = self.checkpointer.checkpoint_dirname(self.global_step) + path = join_path(self.save_folder, dirname) + log.info(f"Saving checkpoint for step {self.global_step} to '{path}'...") + self.checkpointer.save(path, self.model, self.optim, self.state_dict()) + log.info("Checkpoint saved") + return path + + def save_checkpoint_async(self) -> Tuple[PathOrStr, Future]: + """ + Save a checkpoint for the current step to the :data:`save_folder` asynchronously. + + :param done_callback: A function that will be called with the path to the checkpoint + after it's saved successfully. + + :returns: The path/URL to the checkpoint and a future which will complete when the + checkpoint is successfully saved. + """ + step = self.global_step + dirname = self.checkpointer.checkpoint_dirname(step) + path = join_path(self.save_folder, dirname) + + log.info(f"Saving checkpoint for step {step} to '{path}' asynchronously...") + fut = self.checkpointer.save_async(path, self.model, self.optim, self.state_dict()) + + def callback(future: Future): + future.result() # ensure it finished successfully + log.info(f"Checkpoint for step {step} saved successfully") + + fut.add_done_callback(callback) + + return path, fut + def record_metric( self, name: str, value: Union[float, torch.Tensor], reduce_type: Optional[ReduceType] = None ): @@ -680,7 +718,7 @@ def record_metric( def write_file(self, name: str, contents: Union[str, bytes]) -> PathOrStr: """ - Write a file to the save folder. + Write a file to the :data:`save_folder`. :param fname: The name of the file to write. :param contents: The contents of the file to write. From cba77dac35a1efd0e963c077e2b82788d51b28b3 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 31 Aug 2024 11:26:59 -0700 Subject: [PATCH 2/3] Add `Callback.post_checkpoint_(saved|loaded)` methods --- CHANGELOG.md | 1 + src/olmo_core/train/callbacks/callback.py | 16 ++++++++++++++++ src/olmo_core/train/trainer.py | 7 +++++++ 3 files changed, 24 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 39cbb355..cd7f349a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `Trainer.save_checkpoint()` and `Trainer.save_checkpoint_async()` methods. +- Added `Callback.post_checkpoint_saved()` and `Callback.post_checkpoint_loaded()` methods. 
### Changed diff --git a/src/olmo_core/train/callbacks/callback.py b/src/olmo_core/train/callbacks/callback.py index 7e47adf3..c3350da2 100644 --- a/src/olmo_core/train/callbacks/callback.py +++ b/src/olmo_core/train/callbacks/callback.py @@ -1,6 +1,8 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Dict +from olmo_core.aliases import PathOrStr + if TYPE_CHECKING: from ..trainer import Trainer @@ -44,6 +46,12 @@ def step(self) -> int: # def load_state_dict(self, state_dict: Dict[str, Any]): # del state_dict + def post_checkpoint_loaded(self, path: PathOrStr): + """ + Called when a checkpoint is successfully loaded. + """ + del path + def pre_train(self): """ Runs before the training loop starts. @@ -86,6 +94,14 @@ def post_step(self): """ pass + def post_checkpoint_saved(self, path: PathOrStr): + """ + Called when a checkpoint is successfully saved. + + :param path: The path/URL to the checkpoint. + """ + del path + def log_metrics(self, step: int, metrics: Dict[str, float]): """ Called when metrics have been gathered for a given step (possibly a previous step). diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 9e360474..8ec0e021 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -626,6 +626,9 @@ def load_checkpoint( assert trainer_state is not None self.load_state_dict(trainer_state) + for callback in self.callbacks.values(): + callback.post_checkpoint_loaded(dir) + self._checkpoint_loaded = True log.info("Checkpoint successfully loaded") @@ -664,6 +667,8 @@ def save_checkpoint(self) -> PathOrStr: path = join_path(self.save_folder, dirname) log.info(f"Saving checkpoint for step {self.global_step} to '{path}'...") self.checkpointer.save(path, self.model, self.optim, self.state_dict()) + for callback in self.callbacks.values(): + callback.post_checkpoint_saved(path) log.info("Checkpoint saved") return path @@ -686,6 +691,8 @@ def save_checkpoint_async(self) -> Tuple[PathOrStr, Future]: def callback(future: Future): future.result() # ensure it finished successfully + for callback in self.callbacks.values(): + callback.post_checkpoint_saved(path) log.info(f"Checkpoint for step {step} saved successfully") fut.add_done_callback(callback) From 4d263f9e7f6bf6f60e6e078690313b487fed7007 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sat, 31 Aug 2024 11:51:15 -0700 Subject: [PATCH 3/3] Add `ConfigSaverCallback` --- CHANGELOG.md | 1 + src/examples/train.py | 14 ++++---- src/olmo_core/train/callbacks/__init__.py | 2 ++ src/olmo_core/train/callbacks/callback.py | 2 ++ src/olmo_core/train/callbacks/config_saver.py | 32 +++++++++++++++++++ src/olmo_core/train/trainer.py | 14 ++++---- src/scripts/train/OLMo-7B.py | 15 ++++----- 7 files changed, 58 insertions(+), 22 deletions(-) create mode 100644 src/olmo_core/train/callbacks/config_saver.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cd7f349a..40faaf15 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `Trainer.save_checkpoint()` and `Trainer.save_checkpoint_async()` methods. - Added `Callback.post_checkpoint_saved()` and `Callback.post_checkpoint_loaded()` methods. +- Added `ConfigSaverCallback`. ### Changed diff --git a/src/examples/train.py b/src/examples/train.py index 16b10e26..9be95b30 100644 --- a/src/examples/train.py +++ b/src/examples/train.py @@ -6,7 +6,6 @@ torchrun --nproc-per-node=4 src/examples/train.py run_name [OVERRIDES...] 
""" -import json import sys from dataclasses import dataclass from typing import List, cast @@ -14,7 +13,7 @@ from olmo_core.config import Config, DType from olmo_core.data import MemMapDatasetConfig, TokenizerConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType -from olmo_core.distributed.utils import get_rank, init_hybrid_shard_mesh +from olmo_core.distributed.utils import init_hybrid_shard_mesh from olmo_core.nn.transformer import TransformerConfig from olmo_core.optim import AdamWConfig, CosWithWarmup from olmo_core.train import ( @@ -24,6 +23,7 @@ ) from olmo_core.train.callbacks import ( CheckpointerCallback, + ConfigSaverCallback, GPUMemoryMonitorCallback, GradClipperCallback, SchedulerCallback, @@ -99,6 +99,7 @@ def build_config(run_name: str, overrides: List[str]) -> ExperimentConfig: enabled=False, # change to true to enable ), ) + .with_callback("config_saver", ConfigSaverCallback()) ) return ExperimentConfig( @@ -122,11 +123,10 @@ def main(run_name: str, overrides: List[str]): dataset = config.dataset.build() trainer = config.trainer.build(model, optim, dataset) - # Save config to W&B and file. - if get_rank() == 0: - config_dict = config.as_config_dict() - cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict - trainer.write_file("config.json", json.dumps(config_dict, indent=2)) + # Save config to W&B and each checkpoint dir. + config_dict = config.as_config_dict() + cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict + cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict # Train. trainer.fit() diff --git a/src/olmo_core/train/callbacks/__init__.py b/src/olmo_core/train/callbacks/__init__.py index 5feaaa0a..5c295642 100644 --- a/src/olmo_core/train/callbacks/__init__.py +++ b/src/olmo_core/train/callbacks/__init__.py @@ -1,5 +1,6 @@ from .callback import Callback from .checkpointer import CheckpointerCallback, CheckpointRemovalStrategy +from .config_saver import ConfigSaverCallback from .console_logger import ConsoleLoggerCallback from .garbage_collector import GarbageCollectorCallback from .gpu_memory_monitor import GPUMemoryMonitorCallback @@ -12,6 +13,7 @@ "Callback", "CheckpointerCallback", "CheckpointRemovalStrategy", + "ConfigSaverCallback", "ConsoleLoggerCallback", "GarbageCollectorCallback", "GPUMemoryMonitorCallback", diff --git a/src/olmo_core/train/callbacks/callback.py b/src/olmo_core/train/callbacks/callback.py index c3350da2..8a345a46 100644 --- a/src/olmo_core/train/callbacks/callback.py +++ b/src/olmo_core/train/callbacks/callback.py @@ -49,6 +49,8 @@ def step(self) -> int: def post_checkpoint_loaded(self, path: PathOrStr): """ Called when a checkpoint is successfully loaded. + + :param path: The path/URL to the checkpoint. """ del path diff --git a/src/olmo_core/train/callbacks/config_saver.py b/src/olmo_core/train/callbacks/config_saver.py new file mode 100644 index 00000000..423ea1f0 --- /dev/null +++ b/src/olmo_core/train/callbacks/config_saver.py @@ -0,0 +1,32 @@ +import json +import logging +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from olmo_core.aliases import PathOrStr +from olmo_core.distributed.utils import get_rank + +from .callback import Callback + +log = logging.getLogger(__name__) + + +@dataclass +class ConfigSaverCallback(Callback): + """ + A callback that writes an arbitrary JSON-serializable config dictionary to every checkpoint + directory written during training. 
+ """ + + config: Optional[Dict[str, Any]] = None + fname: str = "config.json" + + def post_checkpoint_saved(self, path: PathOrStr): + if get_rank() != 0: + return + + if self.config is None: + log.warning(f"Config not set on {self.__class__.__name__}, doing nothing") + return + + self.trainer.write_file(self.fname, json.dumps(self.config), dir=path) diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 8ec0e021..4cec7dc0 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -676,11 +676,8 @@ def save_checkpoint_async(self) -> Tuple[PathOrStr, Future]: """ Save a checkpoint for the current step to the :data:`save_folder` asynchronously. - :param done_callback: A function that will be called with the path to the checkpoint - after it's saved successfully. - :returns: The path/URL to the checkpoint and a future which will complete when the - checkpoint is successfully saved. + checkpoint is successfully saved. """ step = self.global_step dirname = self.checkpointer.checkpoint_dirname(step) @@ -723,16 +720,19 @@ def record_metric( self._metrics[self.global_step][name] = value self._metrics_reduce_type[name] = reduce_type - def write_file(self, name: str, contents: Union[str, bytes]) -> PathOrStr: + def write_file( + self, name: str, contents: Union[str, bytes], dir: Optional[PathOrStr] = None + ) -> PathOrStr: """ - Write a file to the :data:`save_folder`. + Write a file to the :data:`save_folder` or ``dir``, if provided. :param fname: The name of the file to write. :param contents: The contents of the file to write. + :param dir: The path/URL to a directory to write the file to. Defaults to :data:`save_folder`. :returns: The path/URL of the file. """ - return self.checkpointer.write_file(self.save_folder, name, contents) + return self.checkpointer.write_file(dir or self.save_folder, name, contents) def _duration_due(self, duration: Duration) -> bool: if duration.unit == DurationUnit.steps: diff --git a/src/scripts/train/OLMo-7B.py b/src/scripts/train/OLMo-7B.py index bad042e8..c36fd3bb 100644 --- a/src/scripts/train/OLMo-7B.py +++ b/src/scripts/train/OLMo-7B.py @@ -2,7 +2,6 @@ Train a 7B OLMo model. Run this script without any arguments to see usage info. """ -import json import logging import sys from dataclasses import dataclass @@ -13,7 +12,7 @@ from olmo_core.config import Config, DType, StrEnum from olmo_core.data import DataMix, MemMapDatasetConfig, TokenizerConfig from olmo_core.distributed.parallel import DataParallelConfig, DataParallelType -from olmo_core.distributed.utils import get_num_nodes, get_rank, init_hybrid_shard_mesh +from olmo_core.distributed.utils import get_num_nodes, init_hybrid_shard_mesh from olmo_core.io import is_url from olmo_core.launch.beaker import ( BeakerEnvSecret, @@ -30,6 +29,7 @@ ) from olmo_core.train.callbacks import ( CheckpointerCallback, + ConfigSaverCallback, GPUMemoryMonitorCallback, GradClipperCallback, SchedulerCallback, @@ -174,6 +174,7 @@ def build_config(run_name: str, cluster: str, overrides: List[str]) -> Experimen cancel_check_interval=10, ), ) + .with_callback("config_saver", ConfigSaverCallback()) ) return ExperimentConfig( @@ -205,12 +206,10 @@ def train(config: ExperimentConfig): dataset = config.dataset.build() trainer = config.trainer.build(model, optim, dataset) - # Record the config to W&B and to the save folder. 
- if get_rank() == 0: - config_dict = config.as_config_dict() - wandb_callback = cast(WandBCallback, trainer.callbacks["wandb"]) - wandb_callback.config = config_dict - trainer.write_file("config.json", json.dumps(config_dict, indent=2)) + # Record the config to W&B and each checkpoint dir. + config_dict = config.as_config_dict() + cast(WandBCallback, trainer.callbacks["wandb"]).config = config_dict + cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict # Train. trainer.fit()
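
Usage sketch (not part of the patches above): a minimal example of how the pieces added in this series could fit together, assuming a `trainer` built from a `TrainerConfig` that registered a `"config_saver"` callback as in `src/examples/train.py`. The `LatestCheckpointCallback` class and the `save_and_wait()` helper are hypothetical names used only to illustrate the new `Callback.post_checkpoint_saved()` hook and the `Trainer.save_checkpoint(_async)` methods.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, cast

from olmo_core.aliases import PathOrStr
from olmo_core.train.callbacks import Callback, ConfigSaverCallback
from olmo_core.train.trainer import Trainer


@dataclass
class LatestCheckpointCallback(Callback):
    """
    Hypothetical callback: records the path of every checkpoint the trainer saves.
    It would be registered the same way as the "config_saver" callback above,
    e.g. ``.with_callback("latest_ckpt", LatestCheckpointCallback())`` on the
    trainer config.
    """

    paths: List[str] = field(default_factory=list)

    def post_checkpoint_saved(self, path: PathOrStr):
        # Fires after every successful save, synchronous or asynchronous.
        self.paths.append(str(path))


def save_and_wait(trainer: Trainer, config_dict: Dict[str, Any]) -> str:
    # Give the config saver the dictionary it should drop into each checkpoint
    # directory (it writes ``config.json`` there from post_checkpoint_saved()).
    cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

    # Blocking save: returns once the checkpoint is written and callbacks have run.
    path = trainer.save_checkpoint()

    # Non-blocking save: returns the target path immediately plus a future that
    # completes (and triggers post_checkpoint_saved()) when the save finishes.
    async_path, fut = trainer.save_checkpoint_async()
    fut.result()  # wait for the async save to land

    # Extra files can be written next to a checkpoint with the new ``dir`` argument.
    trainer.write_file("notes.txt", "manual checkpoint", dir=async_path)

    return str(path)
```

One design note: for asynchronous saves, the `post_checkpoint_saved()` hooks run inside the future's done-callback after `future.result()` confirms the save succeeded, so callbacks such as `ConfigSaverCallback` are never invoked on a half-written checkpoint directory.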