Make training on the MPS device work (#131)
dirkgr authored Jan 24, 2025
1 parent b4a195b commit 075a36a
Showing 11 changed files with 52 additions and 52 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -57,3 +57,4 @@ site/
/wandb/
/scratch/
core
/dataset-cache/
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new LR schedulers: `LinearWithWarmup`, `InvSqrtWithWarmup`, `ConstantWithWarmup`, `SequentialScheduler`.
- Added option to pre-download checkpoint files from remote storage before trying to load a checkpoint.
- Added a callback for sending Slack notifications.
- Made the MPS device work on Apple Silicon.
- Added `SkipStepAdamW` optimizer.
- The trainer can load model-only checkpoints now.
- Added the option to throttle checkpoint uploads to one rank from each node at a time.
4 changes: 3 additions & 1 deletion src/olmo_core/data/data_loader.py
@@ -95,8 +95,10 @@ def __init__(
self.collator = collator
self.work_dir = work_dir
self.global_batch_size = global_batch_size
assert dp_rank < dp_world_size
self.dp_world_size = dp_world_size
self.dp_rank = dp_rank

self.fs_local_rank = fs_local_rank

self.batches_processed = 0
@@ -432,7 +434,7 @@ def reshuffle(self, epoch: Optional[int] = None, in_memory: bool = False, **kwar
self.build_and_save_global_indices(in_memory=in_memory)

def get_mock_batch(self) -> Dict[str, Any]:
rng = torch.Generator()
rng = torch.Generator(device=get_default_device())
rng.manual_seed(self.seed + self.dp_rank)
num_instances = self.rank_batch_size // self.dataset.max_sequence_length
input_ids = torch.randint(
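Not part of the diff — a minimal sketch of why the mock-batch generator now needs an explicit device: once the default device is MPS (or CUDA), factory functions such as torch.randint typically reject a generator that lives on the CPU, so the generator has to be created on the same device. It assumes a PyTorch build that supports MPS generators.

import torch
from olmo_core.utils import get_default_device  # cuda > mps > cpu, see the utils.py hunk below

device = get_default_device()
rng = torch.Generator(device=device)  # generator device must match the tensor device
rng.manual_seed(42)
mock_input_ids = torch.randint(0, 32_000, (2, 16), generator=rng, device=device)
print(mock_input_ids.device)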
2 changes: 1 addition & 1 deletion src/olmo_core/distributed/utils.py
@@ -243,7 +243,7 @@ def get_world_size(group: Optional[dist.ProcessGroup] = None) -> int:
if is_distributed():
return dist.get_world_size(group)
else:
return 0
return 1


def get_local_world_size() -> int:
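Not part of the diff — a short sketch of why the non-distributed fallback has to be 1 rather than 0: per-rank quantities are typically derived by dividing by the data-parallel world size. The helper below is hypothetical, purely to illustrate the failure mode.

def rank_batch_size(global_batch_size: int, dp_world_size: int) -> int:
    # Hypothetical: split the global batch across data-parallel ranks.
    return global_batch_size // dp_world_size

print(rank_batch_size(512, 1))  # 512 — a single-process run gets the whole batch
# rank_batch_size(512, 0)       # ZeroDivisionError with the old return value of 0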
2 changes: 2 additions & 0 deletions src/olmo_core/internal/common.py
@@ -42,6 +42,8 @@ def get_root_dir(cluster: str) -> str:
root_dir = "/weka/oe-training-default/ai2-llm"
elif "augusta" in cluster:
root_dir = "gs://ai2-llm"
elif "local" in cluster:
root_dir = "gs://ai2-llm"
return root_dir


25 changes: 24 additions & 1 deletion src/olmo_core/internal/experiment.py
@@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, cast

import torch
from rich import print

from olmo_core.config import Config, StrEnum
@@ -72,6 +73,7 @@ class ExperimentConfig(Config):
class SubCmd(StrEnum):
launch = "launch"
train = "train"
train_single = "train_single"
prep = "prep"
launch_prep = "launch_prep"
dry_run = "dry_run"
@@ -81,6 +83,8 @@ def prepare_environment(self):
prepare_cli_environment()
elif self == SubCmd.train:
prepare_training_environment()
elif self == SubCmd.train_single:
prepare_training_environment(backend=None)
else:
raise NotImplementedError(self)

@@ -102,6 +106,23 @@ def run(self, config: ExperimentConfig):
train(config)
finally:
teardown_training_environment()
elif self == SubCmd.train_single:
if config.model.dp_config is not None:
log.warning(
"dp_config is set to %s, but you can't use data parallelism when running on a single node. Disabling.",
config.model.dp_config,
)
config.model.dp_config = None
if config.model.tp_config is not None:
log.warning(
"tp_config is set to %s, but you can't use tensor parallelism when running on a single node. Disabling.",
config.model.dp_config,
)
config.model.tp_config = None
try:
train(config)
finally:
teardown_training_environment()
elif self == SubCmd.prep:
prep(config)
elif self == SubCmd.launch_prep:
@@ -157,7 +178,6 @@ def build_common_components(

callbacks: Dict[str, Callback] = {
"lr_scheduler": SchedulerCallback(scheduler=CosWithWarmup(warmup_steps=2000)),
"gpu_monitor": GPUMemoryMonitorCallback(),
"grad_clipper": GradClipperCallback(max_grad_norm=1.0),
"config_saver": ConfigSaverCallback(),
"profiler": ProfilerCallback(enabled=False),
@@ -175,6 +195,8 @@
),
"slack_notifier": SlackNotifierCallback(name=run_name, enabled=False),
}
if torch.cuda.is_available():
callbacks["gpu_monitor"] = GPUMemoryMonitorCallback()

return CommonComponents(
run_name=run_name,
@@ -306,6 +328,7 @@ def main(
[b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand.
[b magenta]train:[/] Run the trainer. You usually shouldn't invoke the script with this subcommand directly.
Instead use [b magenta]launch[/] or run it with torchrun.
[b magenta]train_single:[/] Run the trainer on a single device (GPU, CPU, MPS). num_nodes is ignored.
[b magenta]prep:[/] Prepare the dataset ahead of training to save GPU time.
[b magenta]launch_prep:[/] Launch the script on Beaker with the [b magenta]prep[/] subcommand.
[b magenta]dry_run:[/] Pretty print the config and exit.
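Not part of the diff — a hedged sketch of how the new subcommand slots into the existing flow; SubCmd and prepare_environment come from the hunks above, and building the ExperimentConfig is elided.

from olmo_core.internal.experiment import SubCmd

cmd = SubCmd.train_single
cmd.prepare_environment()  # ends up calling prepare_training_environment(backend=None)
# cmd.run(config)          # would warn about and disable dp/tp configs, then call train(config)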
7 changes: 6 additions & 1 deletion src/olmo_core/nn/transformer/config.py
@@ -234,7 +234,12 @@ def build(

# Maybe compile.
if self.compile:
model.apply_compile()
if torch.cuda.is_available():
model.apply_compile()
else:
log.warning(
"model.compile was set to True, but CUDA is not available. Compiling only works with CUDA. Ignoring."
)

# Maybe wrap for data parallel.
if dp_mesh is None and mesh is not None:
5 changes: 4 additions & 1 deletion src/olmo_core/train/__init__.py
@@ -43,12 +43,13 @@
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from ..distributed.utils import init_distributed, is_distributed
from ..io import add_cached_path_clients
from ..utils import LogFilterType, prepare_cli_environment, seed_all
from ..utils import LogFilterType, get_default_device, prepare_cli_environment, seed_all
from .checkpoint import Checkpointer, CheckpointerConfig
from .common import Duration, DurationUnit, LoadStrategy, ReduceType
from .config import TrainerConfig
Expand Down Expand Up @@ -117,6 +118,8 @@ def prepare_training_environment(
# Initialize process group.
if backend is not None:
init_distributed(backend=backend, timeout=timeout)
else:
torch.set_default_device(get_default_device())

# Configure logging, warning filters, exception hooks, and other CLI settings.
prepare_cli_environment(log_filter_type=log_filter_type)
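Not part of the diff — a sketch of the single-device path this enables: with backend=None no process group is initialized and the default device (CUDA, MPS, or CPU) is set globally instead. It assumes teardown_training_environment is exported from olmo_core.train alongside prepare_training_environment, as the experiment.py usage above suggests.

from olmo_core.train import prepare_training_environment, teardown_training_environment

prepare_training_environment(backend=None)  # skips init_distributed, sets torch's default device
try:
    ...  # build the model, data loader, and trainer here, then call trainer.fit()
finally:
    teardown_training_environment()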
10 changes: 8 additions & 2 deletions src/olmo_core/train/trainer.py
@@ -382,7 +382,12 @@ def __post_init__(self):
else:
self._loss_fn = cross_entropy_loss
if self.compile_loss:
self._loss_fn = torch.compile(self._loss_fn)
if torch.cuda.is_available():
self._loss_fn = torch.compile(self._loss_fn)
else:
log.warning(
"compile_loss was set to True, but CUDA is not available. Compiling only works with CUDA. Ignoring."
)

@property
def global_batch_size(self) -> int:
Expand Down Expand Up @@ -1324,7 +1329,8 @@ def _fit_epoch(self):

if first_batch or self.global_step % self.metrics_collect_interval == 0:
self._log_metrics()
torch.cuda.set_sync_debug_mode("warn")
if torch.cuda.is_available():
torch.cuda.set_sync_debug_mode("warn")

first_batch = False

2 changes: 2 additions & 0 deletions src/olmo_core/utils.py
@@ -124,6 +124,8 @@ def get_default_device() -> torch.device:
"""
if torch.cuda.is_available() and torch.cuda.is_initialized():
return torch.device("cuda")
elif torch.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")

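Not part of the diff — a sketch of the fallback order on an Apple Silicon machine without CUDA, where the new MPS branch should now be taken (assuming a PyTorch build with MPS support).

import torch
from olmo_core.utils import get_default_device

device = get_default_device()  # cuda if initialized, else mps, else cpu
x = torch.randn(4, 4, device=device)
print(device, x.mean().item())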
45 changes: 0 additions & 45 deletions src/scripts/train/OLMo2-1B.py
@@ -9,9 +9,6 @@
from olmo_core.optim import AdamWConfig, OptimGroupOverride
from olmo_core.train import TrainerConfig
from olmo_core.train.callbacks import CheckpointerCallback, CometCallback, WandBCallback
from olmo_core.train.callbacks.evaluator_callback import (
DownstreamEvaluatorCallbackConfig,
)


def build_model_config(common: CommonComponents) -> TransformerConfig:
@@ -76,48 +73,6 @@ def build_trainer_config(common: CommonComponents) -> TrainerConfig:
cancel_check_interval=10,
),
)
.with_callback(
"downstream_evaluator",
DownstreamEvaluatorCallbackConfig(
tasks=[
"arc_challenge_val_rc_5shot",
"arc_challenge_val_mc_5shot",
"arc_challenge_test_rc_5shot",
"arc_challenge_test_mc_5shot",
"arc_easy_val_rc_5shot",
"arc_easy_val_mc_5shot",
"arc_easy_test_rc_5shot",
"arc_easy_test_mc_5shot",
"boolq_val_rc_5shot",
"boolq_val_mc_5shot",
"csqa_val_rc_5shot",
"csqa_val_mc_5shot",
"hellaswag_val_rc_5shot",
"hellaswag_val_mc_5shot",
"openbookqa_val_rc_5shot",
"openbookqa_val_mc_5shot",
"openbookqa_test_rc_5shot",
"openbookqa_test_mc_5shot",
"piqa_val_rc_5shot",
"piqa_val_mc_5shot",
"socialiqa_val_rc_5shot",
"socialiqa_val_mc_5shot",
"winogrande_val_rc_5shot",
"winogrande_val_mc_5shot",
"mmlu_stem_val_rc_5shot",
"mmlu_stem_val_mc_5shot",
"mmlu_humanities_val_rc_5shot",
"mmlu_humanities_val_mc_5shot",
"mmlu_social_sciences_val_rc_5shot",
"mmlu_social_sciences_val_mc_5shot",
"mmlu_other_val_rc_5shot",
"mmlu_other_val_mc_5shot",
],
tokenizer=common.tokenizer,
eval_batch_size=1024 * 4096,
eval_interval=1000,
),
)
)


