From bb9bd69646eebbc8d154dbf2646c1cc9683809c2 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 08:42:31 +0100 Subject: [PATCH 1/5] enabling single accelerator training (e.g. for the MPS backend on Macs) --- olmo/checkpoint.py | 12 ++++++---- olmo/config.py | 4 ++++ olmo/torch_util.py | 8 +++++++ olmo/train.py | 8 +++++-- scripts/train.py | 59 ++++++++++++++++++++++++++++++++-------------- 5 files changed, 66 insertions(+), 25 deletions(-) diff --git a/olmo/checkpoint.py b/olmo/checkpoint.py index 7e5e7a137..6aace690f 100644 --- a/olmo/checkpoint.py +++ b/olmo/checkpoint.py @@ -52,6 +52,7 @@ from .exceptions import OLMoCheckpointError from .optim import Optimizer, fix_optim_state_dict from .safetensors_util import safetensors_file_to_state_dict +from .torch_util import SingleAccelerator as SINGLE from .torch_util import ( barrier, gc_cuda, @@ -645,7 +646,7 @@ def save_checkpoint( self._write_optim_dict( optim_state_dict, checkpoint_dir, upload_to, save_overwrite=self.cfg.save_overwrite ) - elif isinstance(dist_model, DDP): + elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE): # _write_model_dict and _write_optim_dict only write checkpoints for rank 0 # First, get the model state dict from DDP wrapped model model_state_dict = dist_model.module.state_dict() @@ -660,7 +661,7 @@ def save_checkpoint( ) else: log.info( - "`FullCheckpointer.save_checkpoint` only supported for FSDP and DDP distributed strategies!" + "`FullCheckpointer.save_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!" ) # Save trainer state. @@ -757,7 +758,7 @@ def restore_checkpoint( torch.cuda.empty_cache() barrier() del optim_state_dict_to_load - elif isinstance(dist_model, DDP): + elif isinstance(dist_model, DDP) or isinstance(dist_model, SINGLE): # Load model state. with torch.no_grad(): state_dict_to_load = load_state_dict( @@ -773,11 +774,12 @@ def restore_checkpoint( optim.load_state_dict(optim_state_dict_to_load) gc.collect() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() barrier() else: raise NotImplementedError( - "`FullCheckpointer.restore_checkpoint` only supported for FSDP and DDP distributed strategies!" + "`FullCheckpointer.restore_checkpoint` only supported for FSDP, DDP, and SINGLE distributed strategies!" ) # Load other state. diff --git a/olmo/config.py b/olmo/config.py index 4e197a341..91f744745 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -719,6 +719,10 @@ class DistributedStrategy(StrEnum): Wrap OLMo in torch.distributed.fsdp.FullyShardedDataParallel to train across ranks. """ + single = "single" + """ + Train on a single device, i.e., do not distribute trainig. For development and debugging. + """ class DDPGradSyncMode(StrEnum): batch = "batch" diff --git a/olmo/torch_util.py b/olmo/torch_util.py index 0aa52961e..b6f3e5bd3 100644 --- a/olmo/torch_util.py +++ b/olmo/torch_util.py @@ -156,3 +156,11 @@ def get_cumulative_document_lengths(doc_lens: torch.Tensor) -> torch.Tensor: torch.cumsum(doc_lens.masked_select(doc_lens != 0), 0, dtype=torch.int32), ] ) + +class SingleAccelerator(torch.nn.Module): + process_group = None + def __init__(self, module: torch.nn.Module): + super().__init__() + self.module = module + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) diff --git a/olmo/train.py b/olmo/train.py index 105f82e40..75bef2aea 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -332,7 +332,8 @@ def trainer_state_dict(self) -> Dict[str, Any]: "python": random.getstate(), "numpy": np.random.get_state(), "torch": torch.random.get_rng_state(), - "cuda": torch.cuda.get_rng_state(), + "cuda": torch.cuda.get_rng_state() if torch.cuda.is_available() else None, + "mps": torch.mps.get_rng_state() if torch.mps.is_available() else None, }, } @@ -430,7 +431,10 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None: random.setstate(rng_state["python"]) np.random.set_state(rng_state["numpy"]) torch.set_rng_state(rng_state["torch"]) - torch.cuda.set_rng_state(rng_state["cuda"]) + if rng_state["cuda"] is not None: + torch.cuda.set_rng_state(rng_state["cuda"]) + if rng_state["mps"] is not None: + torch.mps.set_rng_state(rng_state["mps"]) def _save_checkpoint( self, checkpointer: Checkpointer, checkpoint_type: CheckpointType diff --git a/scripts/train.py b/scripts/train.py index b4d89be2d..62fb1050e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -27,6 +27,7 @@ from olmo.exceptions import OLMoCliError, OLMoConfigurationError from olmo.model import OLMo from olmo.optim import BoltOnWarmupScheduler, build_optimizer, build_scheduler +from olmo.torch_util import SingleAccelerator as SINGLE from olmo.torch_util import ( barrier, get_default_device, @@ -65,9 +66,14 @@ def main(cfg: TrainConfig) -> None: barrier() # Set CUDA device. - torch.cuda.set_device(f"cuda:{get_local_rank()}") - torch.cuda.empty_cache() - device = torch.device("cuda") + if torch.cuda.is_available(): + torch.cuda.set_device(f"cuda:{get_local_rank()}") + torch.cuda.empty_cache() + device = torch.device("cuda") + elif torch.mps.is_available(): + device = torch.device("mps") + else: + device = torch.device("cpu") # Fill some configuration options. cfg.model.precision = cfg.precision @@ -211,8 +217,9 @@ def dummy_init_fn(module: torch.nn.Module) -> None: param_init_fn=param_init_fn, **hybrid_sharding_fsdp_kwargs, ) - elif cfg.distributed_strategy is None: - raise NotImplementedError("Single accelerator training not implemented yet!") + elif cfg.distributed_strategy == DistributedStrategy.single: + param_init_fn = None + dist_model = SINGLE(olmo_model) # when param_init_fn is None, FSDP will call reset_parameters() automatically if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp: @@ -305,8 +312,21 @@ def dummy_init_fn(module: torch.nn.Module) -> None: checkpoint_type = ( CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded ) - else: - raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") + elif cfg.distributed_strategy == DistributedStrategy.single: + #raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") + checkpoint_type = CheckpointType.unsharded + + if cfg.save_interval_unsharded is None: + log.warning( + "single accelerator training requires setting `save_interval_unsharded`. Using the value set for `save_interval`." + ) + cfg.save_interval_unsharded = cfg.save_interval + + if cfg.save_num_unsharded_checkpoints_to_keep == 0: + log.warning( + "single accelerator training requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`." + ) + cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever). log.info("Saving pre-train checkpoint...") @@ -363,17 +383,20 @@ def dummy_init_fn(module: torch.nn.Module) -> None: print(f"failed to set multiprocessing start method: {e}") log.info(f"Multiprocessing start method set to '{mp.get_start_method()}'") - # Set CUDA device. - torch.cuda.set_device(f"cuda:{get_local_rank()}") - - # Initialize process group. - device_as_string = f"cuda:{get_local_rank()}" - torch.cuda.set_device( - device_as_string - ) # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have. - dist.init_process_group( - backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string) - ) + if torch.cuda.is_available(): + # Set CUDA device. + torch.cuda.set_device(f"cuda:{get_local_rank()}") + + # Initialize process group. + device_as_string = f"cuda:{get_local_rank()}" + torch.cuda.set_device( + device_as_string + ) # Set this early to prevent GPU 0 from picking up a bunch of tensors it shouldn't have. + dist.init_process_group( + backend="nccl", timeout=timedelta(minutes=30), device_id=torch.device(device_as_string) + ) + else: + dist.init_process_group(backend="gloo", timeout=timedelta(minutes=30)) log.info("Process group initialized") prepare_cli_environment() From 4b32b6340ab13088a0e2d8c9d4af87da379eecf0 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 10:21:38 +0100 Subject: [PATCH 2/5] removed code duplication --- scripts/train.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/scripts/train.py b/scripts/train.py index 62fb1050e..62c066970 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -294,7 +294,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: cfg.reset_optimizer_state = False if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None: - if cfg.distributed_strategy == DistributedStrategy.ddp: + if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]: checkpoint_type = CheckpointType.unsharded if cfg.save_interval_unsharded is None: @@ -312,21 +312,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None: checkpoint_type = ( CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded ) - elif cfg.distributed_strategy == DistributedStrategy.single: - #raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") - checkpoint_type = CheckpointType.unsharded - - if cfg.save_interval_unsharded is None: - log.warning( - "single accelerator training requires setting `save_interval_unsharded`. Using the value set for `save_interval`." - ) - cfg.save_interval_unsharded = cfg.save_interval - - if cfg.save_num_unsharded_checkpoints_to_keep == 0: - log.warning( - "single accelerator training requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`." - ) - cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep + else: + raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!") # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever). log.info("Saving pre-train checkpoint...") From 311286cacbee13bc4ef19ad7a6b389497237c2e7 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 10:31:11 +0100 Subject: [PATCH 3/5] backward compatibility for checkpoints --- olmo/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index 75bef2aea..4c1f3b774 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -431,9 +431,9 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None: random.setstate(rng_state["python"]) np.random.set_state(rng_state["numpy"]) torch.set_rng_state(rng_state["torch"]) - if rng_state["cuda"] is not None: + if rng_state.get("cuda", None) is not None: torch.cuda.set_rng_state(rng_state["cuda"]) - if rng_state["mps"] is not None: + if rng_state.get("mps", None) is not None: torch.mps.set_rng_state(rng_state["mps"]) def _save_checkpoint( From 539f64a91e812901da4ab6648c0237aa0ca0e94c Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 11:15:43 +0100 Subject: [PATCH 4/5] reversed logic to ensure checkpointing is unsharded for single accelerator --- olmo/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/olmo/train.py b/olmo/train.py index 4c1f3b774..a7b5426ae 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -1251,7 +1251,7 @@ def on_trace_ready(p): stop_at = min(stop_at, self.global_step + extra_steps) # Maybe save sharded checkpoint. - if self.cfg.distributed_strategy != DistributedStrategy.ddp: + if self.cfg.distributed_strategy == DistributedStrategy.fsdp: if save_checkpoints and ( cancel_initiated or ( From d8f68ea8eed97de7a8efcea80c49019279b3666f Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 15:17:22 +0100 Subject: [PATCH 5/5] should probably do this --- scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/train.py b/scripts/train.py index 62c066970..cc2ad4c0b 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -219,7 +219,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None: ) elif cfg.distributed_strategy == DistributedStrategy.single: param_init_fn = None - dist_model = SINGLE(olmo_model) + dist_model = SINGLE(olmo_model.to(device)) # when param_init_fn is None, FSDP will call reset_parameters() automatically if param_init_fn is not None or cfg.distributed_strategy == DistributedStrategy.ddp: