
Commit

Merge branch 'main' of github.com:allenai/OLMo-core into undfined/sample-bug
undfined committed Jan 23, 2025
2 parents 4934c7c + 7b755c9 commit 8040525
Showing 23 changed files with 487 additions and 176 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.

### Changed

@@ -24,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- Added missing `weights_only=False` argument to fix loading train checkpoints with newer versions of PyTorch.
- Fixed a bug where GCS uploads did not retry on transient failures.

## [v1.7.0](https://github.com/allenai/OLMo-core/releases/tag/v1.7.0) - 2024-11-27

40 changes: 36 additions & 4 deletions src/olmo_core/distributed/checkpoint/__init__.py
@@ -61,8 +61,11 @@
def save_state_dict(
dir: PathOrStr,
state_dict: Dict[str, Any],
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
):
"""
Save an arbitrary state dictionary to a distributed format that can be loaded again with
@@ -76,11 +79,19 @@ def save_state_dict(
:param state_dict: The state dict to save.
:param process_group: The process group to use for distributed collectives.
:param save_overwrite: Overwrite existing files.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
"""
dir = _prepare_env_for_save(dir, process_group=process_group, save_overwrite=save_overwrite)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
)
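A minimal usage sketch for the updated signature (the remote path is hypothetical; `throttle_uploads` only has an effect when `dir` is a URL):

from olmo_core.distributed.checkpoint import save_state_dict

save_state_dict(
    "gs://my-bucket/checkpoints/step1000",  # hypothetical remote dir
    {"step": 1000},
    save_overwrite=True,
    thread_count=4,
    throttle_uploads=True,  # one rank per node uploads at a time
)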

@@ -93,6 +104,8 @@ def save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> None:
"""
Save model and optimizer state dictionaries. The model state can be a sharded model, in which
@@ -115,6 +128,9 @@ def save_model_and_optim_state(
:param optim: The optimizer to save state from.
:param process_group: The process group to use for distributed collectives.
:param save_overwrite: Overwrite existing files.
:param thread_count: Set this to override the number of threads used while writing data.
:param throttle_uploads: If this is set to ``True`` and ``dir`` is a URL then only one
rank from each node will upload data at a time.
:raises FileExistsError: If the checkpoint dir exists and is non-empty unless ``save_overwrite=True``.
"""
@@ -123,7 +139,12 @@ def save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
dist_cp.state_dict_saver.save(
state_dict,
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
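A hedged sketch of the same call for a model/optimizer pair (path and objects are hypothetical; per the docstring, `dir`, `model`, and `optim` come before the keyword-only arguments):

save_model_and_optim_state(
    "s3://my-bucket/checkpoints/step1000",  # hypothetical remote dir
    model,
    optim,
    save_overwrite=True,
    thread_count=8,
    throttle_uploads=True,
)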
@@ -137,6 +158,8 @@ def async_save_model_and_optim_state(
*,
process_group: Optional[dist.ProcessGroup] = None,
save_overwrite: bool = False,
thread_count: Optional[int] = None,
throttle_uploads: bool = False,
) -> Future[None]:
"""
An async version of :func:`save_model_and_optim_state()`.
@@ -148,7 +171,12 @@ def async_save_model_and_optim_state(
planner = DefaultSavePlanner(dedup_save_to_lowest_rank=True)
return dist_cp.state_dict_saver.async_save(
state_dict,
storage_writer=RemoteFileSystemWriter(
dir,
thread_count=thread_count,
process_group=process_group,
throttle_uploads=throttle_uploads,
),
process_group=process_group,
planner=planner,
)
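A short sketch of the async variant (same hypothetical objects as above); the returned `Future` completes once the checkpoint has been fully written:

future = async_save_model_and_optim_state(
    "s3://my-bucket/checkpoints/step2000",  # hypothetical
    model,
    optim,
    save_overwrite=True,
)
# ... keep training while the save runs in the background ...
future.result()  # block until the save has finished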
@@ -164,6 +192,7 @@ def load_model_and_optim_state(
key_mapping: Optional[Dict[str, str]] = None,
pre_download: bool = False,
work_dir: Optional[PathOrStr] = None,
thread_count: Optional[int] = None,
):
"""
Load model and optimizer state in-place from a checkpoint saved via :func:`save_model_and_optim_state()`.
@@ -201,10 +230,13 @@ def load_model_and_optim_state(
This dictionary should map current keys to keys in the checkpoint to be loaded.
:param pre_download: Download and cache relevant remote checkpoint files before trying to read from them.
:param work_dir: A working directory for caching files/directories.
:param thread_count: Set the number of threads used for certain operations.
"""
dir = normalize_path(dir)
state_dict = _prepare_state_dict(model, optim, process_group=process_group)
reader = RemoteFileSystemReader(
dir, thread_count=thread_count, pre_download=pre_download, work_dir=work_dir
)

if key_mapping is not None:
metadata = reader.read_metadata()
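For completeness, a hedged sketch of the matching load call (hypothetical path and objects): `pre_download=True` caches the remote files under `work_dir` before reading, and `thread_count` bounds the reader's parallelism:

load_model_and_optim_state(
    "s3://my-bucket/checkpoints/step1000",  # hypothetical
    model,
    optim,
    pre_download=True,
    work_dir="/tmp/ckpt-cache",
    thread_count=8,
)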
36 changes: 28 additions & 8 deletions src/olmo_core/distributed/checkpoint/filesystem.py
@@ -7,10 +7,12 @@
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, cast

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.filesystem import WriteResult
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex, StorageMeta
@@ -25,6 +27,7 @@
from torch.futures import Future

from olmo_core.aliases import PathOrStr
from olmo_core.distributed.utils import do_n_at_a_time
from olmo_core.exceptions import OLMoCheckpointError
from olmo_core.io import (
get_bytes_range,
@@ -154,12 +157,16 @@ def __init__(
self,
path: PathOrStr,
thread_count: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
throttle_uploads: bool = False,
) -> None:
super().__init__()
if thread_count is not None and thread_count <= 0:
raise ValueError("thread count must be at least 1")
self.path = normalize_path(path)
self.thread_count = thread_count or get_default_thread_count()
self.process_group = process_group
self.throttle_uploads = throttle_uploads
self.save_id = generate_uuid()
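A sketch of using the writer directly with `torch.distributed.checkpoint` (essentially what the helpers in `checkpoint/__init__.py` above do; the path is hypothetical and `state_dict` is assumed to exist):

import torch.distributed.checkpoint as dist_cp

from olmo_core.distributed.checkpoint.filesystem import RemoteFileSystemWriter

writer = RemoteFileSystemWriter(
    "gs://my-bucket/checkpoints/step3000",  # hypothetical remote dir
    thread_count=8,
    throttle_uploads=True,
)
dist_cp.state_dict_saver.save(state_dict, storage_writer=writer)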

def reset(self, checkpoint_id: Optional[PathOrStr] = None) -> None:
@@ -201,22 +208,35 @@ def gen_file_name() -> str:
file_count += 1
return file_name

def write_items(buckets: List[List[WriteItem]]) -> List[WriteResult]:
    results: List[WriteResult] = []
    for bucket in buckets:
        file_name = gen_file_name()
        path = f"{self.path}/{file_name}"
        try:
            results.extend(_write_items(path, file_name, bucket, planner))
        except BaseException:
            # NOTE: we might get an error here that can't be pickled, which causes a different failure
            # later when PyTorch tries to reduce that error across ranks. So here we just make
            # sure we're raising a simple error type that can be pickled.
            raise OLMoCheckpointError(f"Original error:\n{traceback.format_exc()}")
    return results

results: List[WriteResult]
if self.throttle_uploads and is_url(self.path):
    buckets = _split_by_size_and_type(1, plan.items)
    results = do_n_at_a_time(
        partial(write_items, buckets), process_group=self.process_group
    )
else:
    buckets = _split_by_size_and_type(self.thread_count, plan.items)
    results = []
    with ThreadPoolExecutor(max_workers=self.thread_count) as executor:
        futures = []
        for bucket in buckets:
            futures.append(executor.submit(write_items, [bucket]))
        for f in as_completed(futures):
            results.extend(f.result())

fut: Future[List[WriteResult]] = Future()
fut.set_result(results)
37 changes: 36 additions & 1 deletion src/olmo_core/distributed/utils.py
@@ -3,9 +3,10 @@
"""

import logging
import math
import os
from datetime import timedelta
from typing import Callable, List, Optional, TypeVar, cast

import torch
import torch.distributed as dist
@@ -92,6 +93,7 @@ def init_distributed(backend: str = "nccl", timeout: timedelta = timedelta(minut
"enp6s0,enp7s0,enp13s0,enp14s0,enp134s0,enp135s0,enp141s0,enp142s0",
)
set_env_var("NCCL_SOCKET_IFNAME", "enp0s12")
set_env_var("NCCL_DEBUG_SUBSYS", "INIT,NET")

if backend_supports_cuda(backend):
# Set CUDA device.
@@ -420,3 +422,36 @@ def get_local_tensor(x: torch.Tensor) -> torch.Tensor:
return x.to_local()
else:
return x


def do_n_at_a_time(
f: Callable[[], T],
*,
n: Optional[int] = None,
process_group: Optional[dist.ProcessGroup] = None,
world_size: Optional[int] = None,
local_rank: Optional[int] = None,
) -> T:
"""
Call a function ``f`` in a distributed context from at most ``n`` ranks at a time.
All ranks will eventually call the given function exactly once, at which point this function
will return.

:param f: The function to call from each rank.
:param n: The level of concurrency, i.e. how many ranks are allowed to call ``f`` at once.
This defaults to the number of nodes, in which case one rank from each node will
call ``f`` at a time.
:param process_group: The process group to use.
"""
world_size = world_size if world_size is not None else get_world_size(process_group)
local_rank = local_rank if local_rank is not None else get_rank(process_group)
n = n if n is not None else get_num_nodes()
group_count = math.ceil(world_size / n)
group_rank = local_rank % group_count
result: Optional[T] = None
for active_group in range(group_count):
if group_rank == active_group:
result = f()
barrier(process_group)
return cast(T, result)
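A hedged usage sketch (the callable below is hypothetical): every rank calls the function exactly once, but with the default `n` only one rank per node is inside it at any moment, which is how the throttled upload path in `filesystem.py` uses it:

from olmo_core.distributed.utils import do_n_at_a_time

def push_my_shard() -> None:
    ...  # e.g. upload this rank's checkpoint shard (hypothetical work)

# Default n = number of nodes, so one rank per node runs push_my_shard at a time;
# pass e.g. n=2 to allow two ranks in at once.
do_n_at_a_time(push_my_shard)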
1 change: 1 addition & 0 deletions src/olmo_core/internal/common.py
@@ -102,6 +102,7 @@ def build_launch_config(
# Setup python environment.
"conda shell.bash activate base",
"pip install -e '.[all]'",
"pip install --upgrade beaker-py",
# Quickly try a new version of PyTorch like this
# "pip install --upgrade --pre torch==2.6.0.dev20241112+cu121 --index-url https://download.pytorch.org/whl/nightly/cu121",
"pip freeze",
1 change: 1 addition & 0 deletions src/olmo_core/internal/experiment.py
@@ -130,6 +130,7 @@ def build_common_components(
root_dir=root_dir,
cmd=[script, cmd_to_launch, run_name, cluster, *overrides],
cluster=cluster,
nccl_debug=False,
)

beaker_user = get_beaker_username()
51 changes: 38 additions & 13 deletions src/olmo_core/io.py
@@ -532,16 +532,25 @@ def _get_gcs_client():


def _gcs_is_retriable(exc: Exception) -> bool:
from google.api_core.exceptions import BadRequest
from google.api_core.retry import if_transient_error

return (
if_transient_error(exc)
or isinstance(exc, requests.exceptions.Timeout)
or isinstance(exc, BadRequest) # Weird choice, but Google throws this transiently
)


def _get_gcs_retry():
from google.api_core.retry import Retry

return Retry(
predicate=_gcs_is_retriable, # NOTE: it appears google might ignore this
initial=1.0,
maximum=10.0,
multiplier=2.0,
timeout=600.0,
)


@@ -554,7 +563,7 @@ def _get_gcs_conditional_retry():
return ConditionalRetryPolicy(_get_gcs_retry(), is_generation_specified, ["query_params"])


@retriable(retry_condition=_gcs_is_retriable)
def _gcs_file_size(bucket_name: str, key: str) -> int:
from google.api_core.exceptions import NotFound

@@ -569,35 +578,51 @@ def _gcs_file_size(bucket_name: str, key: str) -> int:
return blob.size


@retriable(retry_condition=_gcs_is_retriable)
def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes: int) -> bytes:
from google.api_core.exceptions import NotFound

storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
try:
blob.reload(retry=_get_gcs_retry())
except NotFound:
raise FileNotFoundError(f"gs://{bucket_name}/{key}")
return blob.download_as_bytes(
start=bytes_start,
end=bytes_start + num_bytes - 1,
retry=_get_gcs_retry(),
checksum=None, # type: ignore
)


@retriable(retry_condition=_gcs_is_retriable)
def _gcs_upload(source: Path, bucket_name: str, key: str, save_overwrite: bool = False):
storage_client = _get_gcs_client()
bucket = storage_client.bucket(bucket_name)
blob = bucket.blob(key)
generation: int = 0
if blob.exists(retry=_get_gcs_retry()):
    if not save_overwrite:
        raise FileExistsError(
            f"gs://{bucket_name}/{key} already exists. Use save_overwrite to overwrite it."
        )

    blob.reload(retry=_get_gcs_retry())
    assert blob.generation is not None
    generation = blob.generation

blob.upload_from_filename(
    source,
    if_generation_match=generation,
    retry=_get_gcs_conditional_retry(),
    checksum=None,
)


@retriable(retry_condition=_gcs_is_retriable)
def _gcs_clear_directory(bucket_name: str, prefix: str):
from google.api_core.exceptions import NotFound

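For context on the `if_generation_match` precondition used in `_gcs_upload` above: generation `0` means "only succeed if the object does not exist yet", and a concrete generation means "only overwrite this exact version", so concurrent writers cannot silently clobber each other. A standalone sketch of the same pattern with the google-cloud-storage client (bucket and key are hypothetical):

from google.cloud import storage

client = storage.Client()
blob = client.bucket("my-bucket").blob("path/to/object")  # hypothetical

generation = 0  # 0 => create only if the object does not exist
if blob.exists():
    blob.reload()
    generation = blob.generation  # overwrite only this exact version

blob.upload_from_filename("local-file.bin", if_generation_match=generation)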