diff --git a/CHANGELOG.md b/CHANGELOG.md index c85f4245..99f0738e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`. - Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint. - Added a callback for sending Slack notifications. +- Added `SkipStepAdamW` optimizer. +- The trainer can load model-only checkpoints now. +- Added the option to throttle checkpoint uploads to one rank from each node at a time. ### Changed @@ -24,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed - Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch. +- Fixed bug where GCS upload does not retry on transient failures. ## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27 diff --git a/src/olmo_core/distributed/checkpoint/__init__.py b/src/olmo_core/distributed/checkpoint/__init__.py index 70d0e542..9e9bac6c 100644 --- a/src/olmo_core/distributed/checkpoint/__init__.py +++ b/src/olmo_core/distributed/checkpoint/__init__.py @@ -61,8 +61,11 @@ def save_state_dict( dir: PathOrStr, state_dict: Dict[str, Any], + *, process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, + throttle_uploads: bool = False, ): """ Save an arbitrary state dictionary to a distributed format that can loaded again with @@ -76,11 +79,19 @@ def save_state_dict( :param state_dict: The state dict to save. :param process_group: The process group to use for distributed collectives. :param save_overwrite: Overwrite existing files. + :param thread_count: Set this to override the number of threads used while writing data. + :param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one + rank from each node will upload data at a time. """ dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite) dist_cp.state_dict_saver.save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter( + dir, + thread_count=thread_count, + process_group=process_group, + throttle_uploads=throttle_uploads, + ), process_group=process_group, ) @@ -93,6 +104,8 @@ def save_model_and_optim_state( *, process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, + throttle_uploads: bool = False, ) -> None: """ Save model and optimizer state dictionaries. The model state can be a sharded model, in which @@ -115,6 +128,9 @@ def save_model_and_optim_state( :param optim: The optimizer to save state from. :param process_group: The process group to use for distributed collectives. :param save_overwrite: Overwrite existing files. + :param thread_count: Set this to override the number of threads used while writing data. + :param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one + rank from each node will upload data at a time. :raises FileExistsError: If the checkpoint dir exists and is non-empty unless ``save_overwrite=True``. 
""" @@ -123,7 +139,12 @@ def save_model_and_optim_state( planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) dist_cp.state_dict_saver.save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter( + dir, + thread_count=thread_count, + process_group=process_group, + throttle_uploads=throttle_uploads, + ), process_group=process_group, planner=planner, ) @@ -137,6 +158,8 @@ def async_save_model_and_optim_state( *, process_group: Optional[dist.ProcessGroup] = None, save_overwrite: bool = False, + thread_count: Optional[int] = None, + throttle_uploads: bool = False, ) -> Future[None]: """ An async version of :func:`save_model_and_optim_state()`. @@ -148,7 +171,12 @@ def async_save_model_and_optim_state( planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True) return dist_cp.state_dict_saver.async_save( state_dict, - storage_writer=RemoteFileSystemWriter(dir), + storage_writer=RemoteFileSystemWriter( + dir, + thread_count=thread_count, + process_group=process_group, + throttle_uploads=throttle_uploads, + ), process_group=process_group, planner=planner, ) @@ -164,6 +192,7 @@ def load_model_and_optim_state( key_mapping: Optional[Dict[str, str]] = None, pre_download: bool = False, work_dir: Optional[PathOrStr] = None, + thread_count: Optional[int] = None, ): """ Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`. @@ -201,10 +230,13 @@ def load_model_and_optim_state( This dictionary should map current keys to keys in the checkpoint to be loaded. :param pre_download: Download and cache relevant remote checkpoint files before trying to read from them. :param work_dir: A working directory for caching files/directories. + :param thread_count: Set the number of threads used for certain operations. 
""" dir = normalize_path(dir) state_dict = _prepare_state_dict(model, optim, process_group=process_group) - reader = RemoteFileSystemReader(dir, pre_download=pre_download, work_dir=work_dir) + reader = RemoteFileSystemReader( + dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir + ) if key_mapping is not None: metadata = reader.read_metadata() diff --git a/src/olmo_core/distributed/checkpoint/filesystem.py b/src/olmo_core/distributed/checkpoint/filesystem.py index 1960c124..5b7a39dc 100644 --- a/src/olmo_core/distributed/checkpoint/filesystem.py +++ b/src/olmo_core/distributed/checkpoint/filesystem.py @@ -7,10 +7,12 @@ import traceback from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple, cast import torch +import torch.distributed as dist import torch.distributed.checkpoint as dist_cp from torch.distributed.checkpoint.filesystem import WriteResult from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta @@ -25,6 +27,7 @@ from torch.futures import Future from olmo_core.aliases import PathOrStr +from olmo_core.distributed.utils import do_n_at_a_time from olmo_core.exceptions import OLMoCheckpointError from olmo_core.io import ( get_bytes_range, @@ -154,12 +157,16 @@ def __init__( self, path: PathOrStr, thread_count: Optional[int] = None, + process_group: Optional[dist.ProcessGroup] = None, + throttle_uploads: bool = False, ) -> None: super().__init__() if thread_count is not None and thread_count <= 0: raise ValueError("thread count must be at least 1") self.path = normalize_path(path) self.thread_count = thread_count or get_default_thread_count() + self.process_group = process_group + self.throttle_uploads = throttle_uploads self.save_id = generate_uuid() def reset(self, checkpoint_id: Optional[PathOrStr] = None) -> None: @@ -201,22 +208,35 @@ def gen_file_name() -> str: file_count += 1 return file_name - with ThreadPoolExecutor(max_workers=self.thread_count) as executor: - futures = [] - for bucket in _split_by_size_and_type(self.thread_count, plan.items): + def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]: + results: List[WriteResult] = [] + for bucket in buckets: file_name = gen_file_name() path = f"{self.path}/{file_name}" - futures.append(executor.submit(_write_items, path, file_name, bucket, planner)) - - results = [] - for f in as_completed(futures): try: - results += f.result() + results.extend(_write_items(path, file_name, bucket, planner)) except BaseException: # NOTE: we might get an error here that can't be pickled, which causes a different failure # later when PyTorch tries to reduce that error across ranks. So here we just make # sure we're raising a simple error type that can be pickled. 
raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}") + return results + + results: List[WriteResult] + if self.throttle_uploads and is_url(self.path): + buckets = _split_by_size_and_type(1, plan.items) + results = do_n_at_a_time( + partial(write_items, buckets), process_group=self.process_group + ) + else: + buckets = _split_by_size_and_type(self.thread_count, plan.items) + results = [] + with ThreadPoolExecutor(max_workers=self.thread_count) as executor: + futures = [] + for bucket in buckets: + futures.append(executor.submit(write_items, [bucket])) + for f in as_completed(futures): + results.extend(f.result()) fut: Future[List[WriteResult]] = Future() fut.set_result(results) diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py index 93b03fb8..c213439a 100644 --- a/src/olmo_core/distributed/utils.py +++ b/src/olmo_core/distributed/utils.py @@ -3,9 +3,10 @@ """ import logging +import math import os from datetime import timedelta -from typing import List, Optional, TypeVar +from typing import Callable, List, Optional, TypeVar, cast import torch import torch.distributed as dist @@ -92,6 +93,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut "enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0", ) set_env_var("NCCL_SOCKET_IFNAME", "enp0s12") + set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET") if backend_supports_cuda(backend): # Set CUDA device. @@ -420,3 +422,36 @@ def get_local_tensor(x: torch.Tensor) -> torch.Tensor: return x.to_local() else: return x + + +def do_n_at_a_time( + f: Callable[[], T], + *, + n: Optional[int] = None, + process_group: Optional[dist.ProcessGroup] = None, + world_size: Optional[int] = None, + local_rank: Optional[int] = None, +) -> T: + """ + Call a function ``f`` in a distributed context from at most ``n`` ranks at a time. + + All ranks will eventually call the given function exactly once, at which point this function + will return. + + :param f: The function to call from each rank. + :param n: The level of concurrency, i.e. how many ranks are allowed to call ``f`` at once. + This defaults to the number of nodes, in which case one rank from each node will + call ``f`` at a time. + :param process_group: The process group to use. + """ + world_size = world_size if world_size is not None else get_world_size(process_group) + local_rank = local_rank if local_rank is not None else get_rank(process_group) + n = n if n is not None else get_num_nodes() + group_count = math.ceil(world_size / n) + group_rank = local_rank % group_count + result: Optional[T] = None + for active_group in range(group_count): + if group_rank == active_group: + result = f() + barrier(process_group) + return cast(T, result) diff --git a/src/olmo_core/internal/common.py b/src/olmo_core/internal/common.py index 1c2d426a..d660f0a0 100644 --- a/src/olmo_core/internal/common.py +++ b/src/olmo_core/internal/common.py @@ -102,6 +102,7 @@ def build_launch_config( # Setup python environment. 
"conda shell.bash activate base", "pip install -e '.[all]'", + "pip install --upgrade beaker-py", # Quickly try a new version of PyTorch like this # "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121", "pip freeze", diff --git a/src/olmo_core/internal/experiment.py b/src/olmo_core/internal/experiment.py index 4d1e9ee7..1015a60d 100644 --- a/src/olmo_core/internal/experiment.py +++ b/src/olmo_core/internal/experiment.py @@ -130,6 +130,7 @@ def build_common_components( root_dir=root_dir, cmd=[script, cmd_to_launch, run_name, cluster, *overrides], cluster=cluster, + nccl_debug=False, ) beaker_user = get_beaker_username() diff --git a/src/olmo_core/io.py b/src/olmo_core/io.py index 5fda2741..42ff17f3 100644 --- a/src/olmo_core/io.py +++ b/src/olmo_core/io.py @@ -532,16 +532,25 @@ def _get_gcs_client(): def _gcs_is_retriable(exc: Exception) -> bool: + from google.api_core.exceptions import BadRequest from google.api_core.retry import if_transient_error - return if_transient_error(exc) or isinstance(exc, requests.exceptions.Timeout) + return ( + if_transient_error(exc) + or isinstance(exc, requests.exceptions.Timeout) + or isinstance(exc, BadRequest) # Weird choice, but Google throws this transiently + ) def _get_gcs_retry(): from google.api_core.retry import Retry return Retry( - predicate=_gcs_is_retriable, initial=1.0, maximum=10.0, multiplier=2.0, timeout=600.0 + predicate=_gcs_is_retriable, # NOTE: it appears google might ignore this + initial=1.0, + maximum=10.0, + multiplier=2.0, + timeout=600.0, ) @@ -554,7 +563,7 @@ def _get_gcs_conditional_retry(): return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"]) -@retriable() +@retriable(retry_condition=_gcs_is_retriable) def _gcs_file_size(bucket_name: str, key: str) -> int: from google.api_core.exceptions import NotFound @@ -569,7 +578,7 @@ def _gcs_file_size(bucket_name: str, key: str) -> int: return blob.size -@retriable() +@retriable(retry_condition=_gcs_is_retriable) def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes: from google.api_core.exceptions import NotFound @@ -577,27 +586,43 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) try: - blob.reload() + blob.reload(retry=_get_gcs_retry()) except NotFound: raise FileNotFoundError(f"gs://{bucket_name}/{key}") return blob.download_as_bytes( - start=bytes_start, end=bytes_start + num_bytes - 1, retry=_get_gcs_retry() + start=bytes_start, + end=bytes_start + num_bytes - 1, + retry=_get_gcs_retry(), + checksum=None, # type: ignore ) -@retriable() +@retriable(retry_condition=_gcs_is_retriable) def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False): storage_client = _get_gcs_client() bucket = storage_client.bucket(bucket_name) blob = bucket.blob(key) - if not save_overwrite and blob.exists(): - raise FileExistsError( - f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." - ) - blob.upload_from_filename(source, retry=_get_gcs_conditional_retry()) + generation: int = 0 + if blob.exists(retry=_get_gcs_retry()): + if not save_overwrite: + raise FileExistsError( + f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it." 
+ ) -@retriable() + blob.reload(retry=_get_gcs_retry()) + assert blob.generation is not None + generation = blob.generation + + blob.upload_from_filename( + source, + if_generation_match=generation, + retry=_get_gcs_conditional_retry(), + checksum=None, + ) + + +@retriable(retry_condition=_gcs_is_retriable) def _gcs_clear_directory(bucket_name: str, prefix: str): from google.api_core.exceptions import NotFound diff --git a/src/olmo_core/launch/beaker.py b/src/olmo_core/launch/beaker.py index 435142f3..7d234d31 100644 --- a/src/olmo_core/launch/beaker.py +++ b/src/olmo_core/launch/beaker.py @@ -317,12 +317,23 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: "#!/usr/bin/env bash", "set -exuo pipefail", "[[ -d /var/lib/tcpxo/lib64 ]] && export LD_LIBRARY_PATH=/var/lib/tcpxo/lib64:$LD_LIBRARY_PATH", + # Setup the kernel cache directory used by pytorch + "mkdir -p /root/.cache/torch/kernels && export PYTORCH_KERNEL_CACHE_PATH=/root/.cache/torch/kernels", "mkdir -p /olmo-core-runtime", "cd /olmo-core-runtime", *self.setup_steps, ] if torchrun: + if any(["augusta" in cluster for cluster in self.clusters]): + entrypoint_script.append( + "export BEAKER_REPLICA_RANK=$(" + "python -m olmo_core.launch.reorder_ranks_in_gcp " + "${BEAKER_REPLICA_RANK} " + "${BEAKER_REPLICA_COUNT} " + "${BEAKER_LEADER_REPLICA_HOSTNAME}" + ")" + ) entrypoint_script.append(" ".join(self._get_torchrun_cmd()) + ' "$@"') else: entrypoint_script.append('python "$@"') @@ -341,7 +352,7 @@ def build_experiment_spec(self, torchrun: bool = True) -> ExperimentSpec: leader_selection=self.num_nodes > 1, host_networking=self.num_nodes > 1 or any(["augusta" in cluster for cluster in self.clusters]), - propagate_failure=True if self.num_nodes > 1 else None, + propagate_failure=False if self.num_nodes > 1 else None, propagate_preemption=True if self.num_nodes > 1 else None, synchronized_start_timeout="90m" if self.num_nodes > 1 else None, resources=TaskResources(gpu_count=self.num_gpus, shared_memory="10GiB"), diff --git a/src/olmo_core/launch/reorder_ranks_in_gcp.py b/src/olmo_core/launch/reorder_ranks_in_gcp.py new file mode 100644 index 00000000..d1381ea2 --- /dev/null +++ b/src/olmo_core/launch/reorder_ranks_in_gcp.py @@ -0,0 +1,70 @@ +import argparse +import sys + +import requests +import torch.distributed as dist +from urllib3.exceptions import MaxRetryError, NameResolutionError + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("rank", type=int, help="Worker number") + parser.add_argument("world_size", type=int, help="Total number of workers") + parser.add_argument("master_addr", help="Hostname of worker 0") + parser.add_argument("--master_port", type=int, default=29501, help="Port for TCPStore") + parser.add_argument("--debug", action="store_true", help="Enable debug mode (outside of GCP)") + args = parser.parse_args() + + # Create or connect to the store + store = dist.TCPStore( + host_name=args.master_addr, + port=args.master_port, + world_size=args.world_size, + is_master=(args.rank == 0), + ) + + # Get our own host id + if args.debug: + import socket + + host_id = f"{socket.gethostname()}_{args.rank}" + else: + try: + response = requests.get( + "http://metadata.google.internal/computeMetadata/v1/instance/attributes/physical_host", + headers={"Metadata-Flavor": "Google"}, + ) + assert response.status_code == 200 + host_id = response.text.strip() + except requests.exceptions.ConnectionError as e: + # Unwrap the exception + e = e.args[0] + if not isinstance(e, MaxRetryError): + 
raise + e = e.reason + if not isinstance(e, NameResolutionError): + raise + # Seems we called this outside of GCP, so we do nothing and just print our original rank. + print(args.rank) + sys.exit(0) + + # Find the index of our host id + store.set(f"node_{args.rank}_hostid", host_id) + store.wait([f"node_{i}_hostid" for i in range(args.world_size)]) + all_host_ids = [store.get(f"node_{i}_hostid").decode("UTF-8") for i in range(args.world_size)] + assert len(set(all_host_ids)) == len(all_host_ids) + assert host_id in all_host_ids + rank0_host_id = all_host_ids[0] + all_host_ids.sort() + # Rank 0 needs to remain rank 0, so we reshuffle around it + rank0_index = all_host_ids.index(rank0_host_id) + all_host_ids = all_host_ids[rank0_index:] + all_host_ids[:rank0_index] + print(all_host_ids.index(host_id)) + + # Make sure we're all done before exiting + store.set(f"node_{args.rank}_done", host_id) + store.wait([f"node_{i}_done" for i in range(args.world_size)]) + + +if __name__ == "__main__": + main() diff --git a/src/olmo_core/nn/transformer/config.py b/src/olmo_core/nn/transformer/config.py index f77ec4e2..44d34252 100644 --- a/src/olmo_core/nn/transformer/config.py +++ b/src/olmo_core/nn/transformer/config.py @@ -460,19 +460,22 @@ def olmo2_13B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": ) @classmethod - def olmo2_26B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": + def olmo2_32B(cls, vocab_size: int, **kwargs) -> "TransformerConfig": """ - A 26B OLMo model config. + A 32B OLMo model config. """ + d_model = 5120 return cls.llama_like( vocab_size=vocab_size, - d_model=7168, - n_layers=kwargs.pop("n_layers", 40), - n_heads=kwargs.pop("n_heads", 56), + d_model=d_model, + n_layers=kwargs.pop("n_layers", 64), + n_heads=kwargs.pop("n_heads", 40), + n_kv_heads=kwargs.pop("n_kv_heads", 8), block_name=kwargs.pop("block_name", TransformerBlockType.reordered_norm), qk_norm=kwargs.pop("qk_norm", True), rope_theta=kwargs.pop("rope_theta", 500_000), - hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 1024), + hidden_size_multiple_of=kwargs.pop("hidden_size_multiple_of", 512), + hidden_size_multiplier=kwargs.pop("hidden_size_multiplier", 27648 / (8 * d_model / 3)), layer_norm_eps=1e-6, **kwargs, ) diff --git a/src/olmo_core/optim/__init__.py b/src/olmo_core/optim/__init__.py index 0e1cf986..c050d5e9 100644 --- a/src/olmo_core/optim/__init__.py +++ b/src/olmo_core/optim/__init__.py @@ -1,5 +1,5 @@ from .adam import AdamConfig -from .adamw import AdamWConfig +from .adamw import AdamWConfig, SkipStepAdamW, SkipStepAdamWConfig from .config import OptimConfig, OptimGroupOverride from .lion import Lion, LionConfig, SkipStepLion, SkipStepLionConfig from .scheduler import ( @@ -18,6 +18,8 @@ "OptimGroupOverride", "SkipStepOptimizer", "AdamWConfig", + "SkipStepAdamWConfig", + "SkipStepAdamW", "AdamConfig", "LionConfig", "Lion", diff --git a/src/olmo_core/optim/adamw.py b/src/olmo_core/optim/adamw.py index bc5f1e46..e4a24c90 100644 --- a/src/olmo_core/optim/adamw.py +++ b/src/olmo_core/optim/adamw.py @@ -1,4 +1,3 @@ -import math from dataclasses import dataclass from typing import Optional, Tuple, Type @@ -6,9 +5,9 @@ import torch.nn as nn from .config import OptimConfig +from .skip_step_optimizer import SkipStepOptimizer -# TODO: use this when we implement a "skip step" version of AdamW. 
def adamw_step( p: nn.Parameter, *, @@ -18,7 +17,7 @@ def adamw_step( weight_decay: float, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, - step: int, + step: torch.Tensor, step_factor: torch.Tensor, ): if p.grad is None: @@ -34,19 +33,87 @@ def adamw_step( exp_avg_sq.mul_(1 - step_factor * (1 - beta2)) exp_avg_sq.add_(step_factor * p.grad * p.grad, alpha=1 - beta2) - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step + bias_correction1 = 1 - beta1 ** (step + 1) + bias_correction2 = 1 - beta2 ** (step + 1) step_size = lr / bias_correction1 - bias_correction2_sqrt = math.sqrt(bias_correction2) - denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) + denom = (exp_avg_sq.sqrt() / bias_correction2.sqrt()).add_(eps) update = -step_size * torch.div(exp_avg, denom) update.mul_(step_factor) p.add_(update) +class SkipStepAdamW(SkipStepOptimizer): + """ + A "skip step" version of :class:`AdamW`. + """ + + def __init__( + self, + params, + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + foreach: Optional[bool] = None, + fused: Optional[bool] = None, + rolling_interval_length: int = 128, + sigma_factor: int = 6, + ) -> None: + assert lr > 0.0 + assert all([0.0 <= beta <= 1.0 for beta in betas]) + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, foreach=foreach, fused=fused + ) + super().__init__( + params, + defaults, + rolling_interval_length=rolling_interval_length, + sigma_factor=sigma_factor, + ) + self._step_skipped: Optional[torch.Tensor] = None + + @property + def step_skipped(self) -> torch.Tensor: + if self._step_skipped is not None: + return self._step_skipped + else: + return torch.tensor(0.0) + + @torch.no_grad() + def step(self, closure=None) -> None: + if closure is not None: + with torch.enable_grad(): + closure() + + step_factor = self.get_step_factor() + self._step_skipped = 1 - step_factor + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + state = self.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0.0, dtype=torch.float32, device=p.device) + state["exp_avg"] = torch.zeros_like(p) + state["exp_avg_sq"] = torch.zeros_like(p) + + adamw_step( + p, + lr=group["lr"], + betas=group["betas"], + eps=group["eps"], + weight_decay=group["weight_decay"], + exp_avg=state["exp_avg"], + exp_avg_sq=state["exp_avg_sq"], + step=state["step"], + step_factor=step_factor, + ) + + @dataclass class AdamWConfig(OptimConfig): # NOTE: omagaconf doesn't like "OptimConfig[torch.optim.AdamW]" """ @@ -63,3 +130,21 @@ class AdamWConfig(OptimConfig): # NOTE: omagaconf doesn't like "OptimConfig[tor @classmethod def optimizer(cls) -> Type[torch.optim.AdamW]: return torch.optim.AdamW + + +@dataclass +class SkipStepAdamWConfig(OptimConfig): + """ + Configuration class for building a :class:`SkipStepAdamW` optimizer. 
+ """ + + lr: float = 1e-3 + betas: Tuple[float, float] = (0.9, 0.999) + eps: float = 1e-8 + weight_decay: float = 1e-2 + rolling_interval_length: int = 128 + sigma_factor: int = 6 + + @classmethod + def optimizer(cls) -> Type[SkipStepAdamW]: + return SkipStepAdamW diff --git a/src/olmo_core/optim/skip_step_optimizer.py b/src/olmo_core/optim/skip_step_optimizer.py index 98ada1bd..40b0b034 100644 --- a/src/olmo_core/optim/skip_step_optimizer.py +++ b/src/olmo_core/optim/skip_step_optimizer.py @@ -91,17 +91,20 @@ def get_step_factor(self) -> torch.Tensor: The tensor can be used within the optimizer's step computation to essentially skip a step without a host-device sync. """ - if len(self._losses) < max(20, self.rolling_interval_length // 2): + if len(self._losses) < max(2, self.rolling_interval_length // 2): return torch.tensor(1.0).to(device=self.device, non_blocking=True) loss_std, loss_mean = torch.std_mean(torch.stack(self._losses[:-1])) if self._grad_norms: grad_norm_std, grad_norm_mean = torch.std_mean(torch.stack(self._grad_norms[:-1])) - return ((self.latest_loss - loss_mean) <= self.sigma_factor * loss_std) and ( - (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std + step_factor = torch.logical_and( + (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std, + (self.latest_grad_norm - grad_norm_mean) <= self.sigma_factor * grad_norm_std, ) else: - return (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + step_factor = (self.latest_loss - loss_mean) <= self.sigma_factor * loss_std + + return step_factor.float() @property def step_skipped(self) -> torch.Tensor: diff --git a/src/olmo_core/train/__init__.py b/src/olmo_core/train/__init__.py index ba59008b..e14f3dc7 100644 --- a/src/olmo_core/train/__init__.py +++ b/src/olmo_core/train/__init__.py @@ -75,7 +75,7 @@ def prepare_training_environment( *, seed: Optional[int] = None, backend: Optional[str] = "cpu:gloo,cuda:nccl", - timeout: timedelta = timedelta(minutes=10), + timeout: timedelta = timedelta(minutes=30), log_filter_type: Optional[LogFilterType] = None, ): """ diff --git a/src/olmo_core/train/callbacks/evaluator_callback.py b/src/olmo_core/train/callbacks/evaluator_callback.py index ea2bfa58..556492b7 100644 --- a/src/olmo_core/train/callbacks/evaluator_callback.py +++ b/src/olmo_core/train/callbacks/evaluator_callback.py @@ -129,7 +129,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]: eval_batch_size = ( self.eval_batch_size if self.eval_batch_size is not None - else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) + else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group) ) dataset = self.eval_dataset.build() if not isinstance(dataset, NumpyPaddedFSLDataset): diff --git a/src/olmo_core/train/callbacks/grad_clipper.py b/src/olmo_core/train/callbacks/grad_clipper.py index 0a0ebbcb..97ad3b8d 100644 --- a/src/olmo_core/train/callbacks/grad_clipper.py +++ b/src/olmo_core/train/callbacks/grad_clipper.py @@ -4,6 +4,7 @@ import torch.nn as nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from olmo_core.distributed.utils import get_local_tensor from olmo_core.optim import SkipStepOptimizer from .callback import Callback @@ -26,6 +27,8 @@ def pre_optim_step(self): self.trainer.model.parameters(), self.max_grad_norm ) + grad_norm = get_local_tensor(grad_norm.detach()) + # NOTE: grad norm is already reduced over ranks, so we set `reduce_type` to `None`. 
self.trainer.record_metric("optim/total grad norm", grad_norm, reduce_type=None) if isinstance(self.trainer.optim, SkipStepOptimizer): diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index b262fe81..1c260098 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -19,10 +19,17 @@ from ..config import Config from ..distributed.checkpoint import ( async_save_model_and_optim_state, + get_checkpoint_metadata, load_model_and_optim_state, save_model_and_optim_state, ) -from ..distributed.utils import barrier, get_fs_local_rank, get_rank, is_distributed +from ..distributed.utils import ( + barrier, + get_fs_local_rank, + get_rank, + is_distributed, + scatter_object, +) from ..exceptions import OLMoConfigurationError from ..io import ( clear_directory, @@ -48,6 +55,9 @@ class CheckpointerConfig(Config): work_dir: Optional[str] = None save_overwrite: Optional[bool] = None pre_download: bool = False + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None + throttle_uploads: bool = False def build(self, process_group: Optional[dist.ProcessGroup] = None, **kwargs) -> "Checkpointer": kwargs = {**self.as_dict(exclude_none=True, recurse=False), **kwargs} @@ -75,6 +85,9 @@ class Checkpointer: save_overwrite: bool = False pre_download: bool = False process_group: Optional[dist.ProcessGroup] = None + save_thread_count: Optional[int] = None + load_thread_count: Optional[int] = None + throttle_uploads: bool = False def __post_init__(self): self.work_dir = Path(self.work_dir) @@ -100,6 +113,8 @@ def save(self, dir: PathOrStr, model: nn.Module, optim: Optimizer, train_state: optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, + throttle_uploads=self.throttle_uploads, ) self._save_metadata(dir, CheckpointMetadata()) @@ -129,6 +144,8 @@ def save_async( optim, process_group=self.process_group, save_overwrite=self.save_overwrite, + thread_count=self.save_thread_count, + throttle_uploads=self.throttle_uploads, ) def done_callback(fut: Future): @@ -146,8 +163,8 @@ def load( model: nn.Module, optim: Optimizer, *, - load_optimizer_state: bool = True, - load_trainer_state: bool = True, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, key_mapping: Optional[Dict[str, str]] = None, ) -> Optional[Dict[str, Any]]: """ @@ -158,27 +175,53 @@ def load( # Maybe load trainer state. trainer_state: Optional[Dict[str, Any]] = None - if load_trainer_state: - try: - trainer_state = torch.load( - cached_path(f"{dir}/train/rank{get_rank()}.pt", quiet=True), weights_only=False - ) - except FileNotFoundError: - # Fall back to rank 0 train state. - # This can happen when we're restoring a checkpoint with a different world size. - trainer_state = torch.load( - cached_path(f"{dir}/train/rank0.pt", quiet=True), weights_only=False - ) + if load_trainer_state is not False: + # Try loading the given rank's state first, then fall back to rank 0 train state if it + # doesn't exist, which can happen when we're restoring a checkpoint with a different world size. + for path in (f"{dir}/train/rank{get_rank()}.pt", f"{dir}/train/rank0.pt"): + try: + trainer_state = torch.load(cached_path(path, quiet=True), weights_only=False) + break + except FileNotFoundError: + pass + + if load_trainer_state is True and trainer_state is None: + raise FileNotFoundError(f"Missing trainer state in checkpoint dir '{dir}'") # Load model and optimizer state. 
+ model_and_optim_dir: str = f"{dir}/model_and_optim" + if get_rank(self.process_group) == 0: + try: + metadata = get_checkpoint_metadata(model_and_optim_dir) + except FileNotFoundError as exc: + # Try base directory, which could be the case if user is trying to load model weights + # (possibly with optimizer state), and not an actual train checkpoint. + if trainer_state is None: + metadata = get_checkpoint_metadata(dir) + model_and_optim_dir = dir + else: + raise FileNotFoundError(f"Missing checkpointing metadata in '{dir}'") from exc + + if load_optimizer_state is None: + for key in metadata.state_dict_metadata.keys(): + if key.startswith("optim."): + load_optimizer_state = True + break + else: + load_optimizer_state = False + + model_and_optim_dir = scatter_object(model_and_optim_dir, group=self.process_group) + load_optimizer_state = scatter_object(load_optimizer_state, group=self.process_group) + load_model_and_optim_state( - f"{dir}/model_and_optim", + model_and_optim_dir, model, optim if load_optimizer_state else None, process_group=self.process_group, key_mapping=key_mapping, pre_download=is_url(dir) and self.pre_download, work_dir=self.work_dir, + thread_count=self.load_thread_count, ) return trainer_state @@ -233,6 +276,8 @@ def dir_is_checkpoint(cls, dir: PathOrStr) -> bool: Check if a directory is a checkpoint directory. """ dir = normalize_path(dir) + if file_exists(f"{dir}/.metadata"): # just model (and maybe optim state), no trainer state + return True paths_to_check = [ f"{dir}/train/rank0.pt", f"{dir}/model_and_optim/.metadata", @@ -299,7 +344,7 @@ def _save_train_state(self, dir: PathOrStr, wd: Path, train_state: Dict[str, Any # NOTE: if 'dir' is a URL, the 'wd' will be a different temp dir for each rank. if is_url(dir) or get_fs_local_rank() == 0: train_dir.mkdir(exist_ok=True, parents=True) - wait_for(train_dir.exists, description=f"Waiting on '{train_dir}' to be created...") + wait_for(train_dir.exists, description=f"Waiting for '{train_dir}' to be created...") torch.save(train_state, train_dir / f"rank{get_rank()}.pt") def _save_metadata(self, dir: PathOrStr, metadata: CheckpointMetadata): diff --git a/src/olmo_core/train/trainer.py b/src/olmo_core/train/trainer.py index 0c8d17aa..213f3277 100644 --- a/src/olmo_core/train/trainer.py +++ b/src/olmo_core/train/trainer.py @@ -668,7 +668,11 @@ def load_state_dict(self, state_dict: TrainerStateDict): ) def load_checkpoint( - self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True + self, + dir: PathOrStr, + *, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, ): """ Load a checkpoint. @@ -698,8 +702,7 @@ def load_checkpoint( load_trainer_state=load_trainer_state, key_mapping=self.load_key_mapping, ) - if load_trainer_state: - assert trainer_state is not None + if trainer_state is not None: self.load_state_dict(cast(TrainerStateDict, trainer_state)) for callback in self.callbacks.values(): @@ -709,7 +712,11 @@ def load_checkpoint( log.info("Checkpoint successfully loaded") def maybe_load_checkpoint( - self, dir: PathOrStr, *, load_optimizer_state: bool = True, load_trainer_state: bool = True + self, + dir: PathOrStr, + *, + load_optimizer_state: Optional[bool] = None, + load_trainer_state: Optional[bool] = None, ) -> bool: """ Like :meth:`load_checkpoint()` but is a no-op if there is no checkpoint in the ``dir`` provided. 
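Taken together, the checkpointing changes above compose as follows. This is a minimal sketch rather than code from the repo: the tiny model/optimizer and the `gs://` path are placeholders, and a real run would normally happen inside a distributed (torchrun) job.

```python
# Sketch of the new checkpointing options; the bucket path is a placeholder.
import torch.nn as nn
from torch.optim import AdamW

from olmo_core.distributed.checkpoint import (
    load_model_and_optim_state,
    save_model_and_optim_state,
)

model = nn.Linear(8, 8)
optim = AdamW(model.parameters())

# With `throttle_uploads=True` and a URL target, only one rank per node uploads
# at a time; `thread_count` overrides the number of writer threads.
save_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",
    model,
    optim,
    save_overwrite=True,
    thread_count=2,
    throttle_uploads=True,
)

# Loading without passing an optimizer restores model weights only.
load_model_and_optim_state(
    "gs://my-bucket/checkpoints/step1000",
    model,
    thread_count=2,
)
```

The same knobs are exposed on `CheckpointerConfig` (`save_thread_count`, `load_thread_count`, `throttle_uploads`) for trainer-managed checkpoints, and `Trainer.load_checkpoint` now auto-detects whether optimizer and trainer state exist in the checkpoint when the corresponding flags are left as `None`.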
diff --git a/src/scripts/train/OLMo2-26B.py b/src/scripts/train/OLMo2-26B.py deleted file mode 100644 index 6453407c..00000000 --- a/src/scripts/train/OLMo2-26B.py +++ /dev/null @@ -1,102 +0,0 @@ -""" -Train a 26B OLMo model. Run this script without any arguments to see usage info. -""" - -import logging - -from olmo_core.config import DType -from olmo_core.distributed.parallel import DataParallelType -from olmo_core.float8 import Float8Config -from olmo_core.internal.experiment import CommonComponents, main -from olmo_core.nn.transformer import ( - TransformerActivationCheckpointingConfig, - TransformerActivationCheckpointingMode, - TransformerConfig, - TransformerDataParallelConfig, -) -from olmo_core.optim import AdamWConfig, OptimGroupOverride -from olmo_core.train import TrainerConfig -from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback - -log = logging.getLogger(__name__) - - -def build_model_config(common: CommonComponents) -> TransformerConfig: - compile = True - return TransformerConfig.olmo2_26B( - vocab_size=common.tokenizer.padded_vocab_size(), - compile=compile, - fused_ops=False, - use_flash=not compile, - dp_config=TransformerDataParallelConfig( - name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 - ), - ac_config=TransformerActivationCheckpointingConfig( - mode=TransformerActivationCheckpointingMode.full - ), - float8_config=Float8Config(compile=compile, enabled=False), - ) - - -def build_optim_config(common: CommonComponents) -> AdamWConfig: - del common - return AdamWConfig( - lr=6e-4, - weight_decay=0.1, - betas=(0.9, 0.95), - group_overrides=[ - OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) - ], - fused=True, - ) - - -def build_trainer_config(common: CommonComponents) -> TrainerConfig: - return ( - TrainerConfig( - save_folder=common.save_folder, - rank_microbatch_size=4 * 4096, - save_overwrite=True, - metrics_collect_interval=10, - cancel_check_interval=1, - z_loss_multiplier=1e-5, - compile_loss=True, - ) - .with_callback( - "checkpointer", - CheckpointerCallback( - save_interval=10_000, - ephemeral_save_interval=250, - save_async=True, - ), - ) - .with_callback( - "comet", - CometCallback( - name=common.run_name, - workspace="ai2", - project="OLMo-core-26B", - enabled=True, - cancel_check_interval=10, - ), - ) - .with_callback( - "wandb", - WandBCallback( - name=common.run_name, - entity="ai2-llm", - project="OLMo-core-26B", - enabled=False, - cancel_check_interval=10, - ), - ) - ) - - -if __name__ == "__main__": - main( - global_batch_size=2048 * 4096, - model_config_builder=build_model_config, - optim_config_builder=build_optim_config, - trainer_config_builder=build_trainer_config, - ) diff --git a/src/test/distributed/checkpoint/filesystem_test.py b/src/test/distributed/checkpoint/filesystem_test.py index 205d4e0a..c89c627e 100644 --- a/src/test/distributed/checkpoint/filesystem_test.py +++ b/src/test/distributed/checkpoint/filesystem_test.py @@ -13,7 +13,7 @@ from ..utils import BACKENDS, get_default_device, run_distributed_test -def run_save_and_load_with_dtensors(dir): +def run_save_and_load_with_dtensors(dir, throttle: bool = False): mesh = init_device_mesh(get_default_device().type, (dist.get_world_size(),)) x_full = torch.randn(4, 4, device=get_default_device()) @@ -30,7 +30,7 @@ def run_save_and_load_with_dtensors(dir): distcp.state_dict_saver.save( {"x": x, "y": y}, checkpoint_id=dir, - storage_writer=RemoteFileSystemWriter(dir, thread_count=2), + 
storage_writer=RemoteFileSystemWriter(dir, thread_count=2, throttle_uploads=throttle), ) # Now create new sharded copies with a different sharding strategy and load the checkpoint. @@ -51,11 +51,17 @@ def run_save_and_load_with_dtensors(dir): @pytest.mark.parametrize("backend", BACKENDS) def test_save_and_load_locally_with_dtensors(backend, tmp_path): - run_distributed_test(run_save_and_load_with_dtensors, backend=backend, func_args=(tmp_path,)) + run_distributed_test( + run_save_and_load_with_dtensors, + backend=backend, + func_args=(tmp_path,), + start_method="spawn", + ) @pytest.mark.parametrize("backend", BACKENDS) -def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir): +@pytest.mark.parametrize("throttle", [True, False]) +def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir, throttle): from botocore.exceptions import NoCredentialsError try: @@ -66,6 +72,6 @@ def test_save_and_load_remotely_with_dtensors(backend, s3_checkpoint_dir): run_distributed_test( run_save_and_load_with_dtensors, backend=backend, - func_args=(s3_checkpoint_dir,), + func_args=(s3_checkpoint_dir, throttle), start_method="spawn", # NOTE: forking causes a crash with boto3 ) diff --git a/src/test/distributed/utils.py b/src/test/distributed/utils.py index 1cc82fe8..01b7fa2d 100644 --- a/src/test/distributed/utils.py +++ b/src/test/distributed/utils.py @@ -1,5 +1,6 @@ import datetime import logging +import os import sys from typing import Any, Callable, Dict, Optional, Tuple @@ -8,7 +9,11 @@ import torch.distributed as dist import torch.multiprocessing as mp -from olmo_core.distributed.utils import is_distributed +from olmo_core.distributed.utils import ( + OLMO_LOCAL_WORLD_SIZE_ENV_VAR, + OLMO_NUM_NODES_ENV_VAR, + is_distributed, +) from ..utils import ( DEVICES, @@ -115,6 +120,9 @@ def log_record_factory(*args, **kwargs) -> logging.LogRecord: timeout=datetime.timedelta(seconds=120), ) + os.environ.setdefault(OLMO_NUM_NODES_ENV_VAR, "1") + os.environ.setdefault(OLMO_LOCAL_WORLD_SIZE_ENV_VAR, str(world_size)) + log.info("Starting test...") if "nccl" in backend: diff --git a/src/test/distributed/utils_test.py b/src/test/distributed/utils_test.py index 0e4ce4bc..90e313d5 100644 --- a/src/test/distributed/utils_test.py +++ b/src/test/distributed/utils_test.py @@ -1,3 +1,5 @@ +from functools import partial + import pytest import torch.distributed as dist @@ -18,3 +20,20 @@ def scatter_object(): @pytest.mark.parametrize("backend", BACKENDS) def test_scatter_object(backend: str): run_distributed_test(scatter_object, backend=backend) + + +@pytest.mark.parametrize("n, world_size", [(2, 1), (8, 64)]) +def test_do_n_at_a_time(n: int, world_size: int): + times_called = 0 + calling_ranks = set() + + def func(rank: int): + nonlocal times_called + times_called += 1 + calling_ranks.add(rank) + + for rank in range(world_size): + dist_utils.do_n_at_a_time(partial(func, rank), n=n, world_size=world_size, local_rank=rank) + + assert times_called == world_size + assert calling_ranks == set(range(world_size)) diff --git a/src/test/optim/adamw_test.py b/src/test/optim/adamw_test.py index 5756f9a6..a792ace9 100644 --- a/src/test/optim/adamw_test.py +++ b/src/test/optim/adamw_test.py @@ -1,7 +1,10 @@ +from test.utils import DEVICES + +import pytest import torch import torch.nn as nn -from olmo_core.optim import AdamWConfig, OptimGroupOverride +from olmo_core.optim import AdamWConfig, OptimGroupOverride, SkipStepAdamWConfig class MyModel(nn.Module): @@ -43,3 +46,33 @@ def 
test_adamw_config_to_optim_with_group_overrides(): for group in optim.param_groups: assert "initial_lr" in group + + +@pytest.mark.parametrize("device", DEVICES) +def test_adamw(device: torch.device): + config = AdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. + optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step() + + +@pytest.mark.parametrize("device", DEVICES) +def test_skip_step_adamw(device: torch.device): + config = SkipStepAdamWConfig() + model = MyModel().train().to(device) + optim = config.build(model) + + for group in optim.param_groups: + assert "initial_lr" in group + + # Take a step. + optim.zero_grad(set_to_none=True) + model(torch.randint(0, 1024, (2, 8), device=device).int()).sum().backward() + optim.step()
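For completeness, here is a hedged sketch of driving the new `SkipStepAdamW` outside the trainer. The tiny model and training loop are placeholders; the `latest_loss` assignment assumes the hook that `SkipStepOptimizer` exposes for populating its rolling loss window (that wiring lives in the trainer and is not part of this diff).

```python
# Sketch: standalone use of SkipStepAdamW via its config.
import torch
import torch.nn as nn

from olmo_core.optim import SkipStepAdamWConfig

model = nn.Linear(16, 16)
optim = SkipStepAdamWConfig(lr=1e-3, rolling_interval_length=128, sigma_factor=6).build(model)

for _ in range(4):
    optim.zero_grad(set_to_none=True)
    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optim.latest_loss = loss.detach()  # assumed hook: feeds the rolling loss window
    optim.step()
    # `step_skipped` is ~1.0 on steps that were treated as outliers and skipped.
    print(float(optim.step_skipped))
```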
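Similarly, the `do_n_at_a_time` helper that backs the throttled uploads can be used directly: every rank calls the given function exactly once, but at most `n` ranks (one per node by default) run it at the same time. `download_blob` below is a hypothetical bandwidth-heavy callable, and the call assumes it runs inside an initialized distributed job.

```python
# Sketch: run a bandwidth-heavy callable from one rank per node at a time.
from functools import partial

from olmo_core.distributed.utils import do_n_at_a_time, get_rank


def download_blob(rank: int) -> str:
    # Placeholder for a real remote download.
    return f"downloaded-by-rank-{rank}"


# All ranks reach this line; `n` defaults to the number of nodes, so one rank
# per node executes `download_blob` at any given moment.
result = do_n_at_a_time(partial(download_blob, get_rank()))
```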