Commit

removed code duplication
peter-sk committed Dec 21, 2024
1 parent bb9bd69 commit 4b32b63
Showing 1 changed file with 3 additions and 16 deletions.
scripts/train.py
@@ -294,7 +294,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         cfg.reset_optimizer_state = False
 
     if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
-        if cfg.distributed_strategy == DistributedStrategy.ddp:
+        if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]:
             checkpoint_type = CheckpointType.unsharded
 
             if cfg.save_interval_unsharded is None:
@@ -312,21 +312,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             checkpoint_type = (
                 CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
             )
-        elif cfg.distributed_strategy == DistributedStrategy.single:
-            #raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")
-            checkpoint_type = CheckpointType.unsharded
-
-            if cfg.save_interval_unsharded is None:
-                log.warning(
-                    "single accelerator training requires setting `save_interval_unsharded`. Using the value set for `save_interval`."
-                )
-                cfg.save_interval_unsharded = cfg.save_interval
-
-            if cfg.save_num_unsharded_checkpoints_to_keep == 0:
-                log.warning(
-                    "single accelerator training requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`."
-                )
-                cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
         else:
             raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")
 
         # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
         log.info("Saving pre-train checkpoint...")