diff --git a/scripts/train.py b/scripts/train.py
index 62fb1050e..62c066970 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -294,7 +294,7 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
         cfg.reset_optimizer_state = False
 
     if not cfg.dry_run and not cfg.no_pre_train_checkpoint and cfg.load_path is None:
-        if cfg.distributed_strategy == DistributedStrategy.ddp:
+        if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]:
             checkpoint_type = CheckpointType.unsharded
 
             if cfg.save_interval_unsharded is None:
@@ -312,21 +312,8 @@ def dummy_init_fn(module: torch.nn.Module) -> None:
             checkpoint_type = (
                 CheckpointType.sharded if cfg.save_num_checkpoints_to_keep != 0 else CheckpointType.unsharded
            )
-        elif cfg.distributed_strategy == DistributedStrategy.single:
-            #raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")
-            checkpoint_type = CheckpointType.unsharded
-
-            if cfg.save_interval_unsharded is None:
-                log.warning(
-                    "single accelerator training requires setting `save_interval_unsharded`. Using the value set for `save_interval`."
-                )
-                cfg.save_interval_unsharded = cfg.save_interval
-
-            if cfg.save_num_unsharded_checkpoints_to_keep == 0:
-                log.warning(
-                    "single accelerator training requires setting `save_num_unsharded_checkpoints_to_keep`. Using the value set for `save_num_checkpoints_to_keep`."
-                )
-                cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
+        else:
+            raise NotImplementedError(f"Distributed strategy {cfg.distributed_strategy} not supported yet!")
 
         # We save a checkpoint up-front to make sure this won't fail (due to disk space or whatever).
         log.info("Saving pre-train checkpoint...")
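
The sketch below is not part of the patch; it is a minimal, self-contained illustration of the checkpoint-type selection as it reads after this change, assuming the surrounding fallback logic for DDP shown in the deleted branch. The helper name `pick_pretrain_checkpoint_type`, the plain `SimpleNamespace` config, and the enum definitions are hypothetical stand-ins for the real OLMo config objects; the key point is that `single` now takes the same unsharded-checkpoint path as `ddp`, and any other strategy except `fsdp` raises.

# Minimal sketch (not scripts/train.py): post-patch checkpoint-type selection.
import logging
from enum import Enum
from types import SimpleNamespace

log = logging.getLogger(__name__)


class DistributedStrategy(str, Enum):
    ddp = "ddp"
    fsdp = "fsdp"
    single = "single"


class CheckpointType(str, Enum):
    sharded = "sharded"
    unsharded = "unsharded"


def pick_pretrain_checkpoint_type(cfg) -> CheckpointType:
    # Hypothetical helper: mirrors the branch structure left in place by the patch.
    if cfg.distributed_strategy in [DistributedStrategy.ddp, DistributedStrategy.single]:
        # DDP and single-accelerator runs only write unsharded checkpoints, so fall
        # back to the generic save settings when the unsharded ones are unset.
        if cfg.save_interval_unsharded is None:
            log.warning("`save_interval_unsharded` not set; using `save_interval`.")
            cfg.save_interval_unsharded = cfg.save_interval
        if cfg.save_num_unsharded_checkpoints_to_keep == 0:
            log.warning(
                "`save_num_unsharded_checkpoints_to_keep` not set; "
                "using `save_num_checkpoints_to_keep`."
            )
            cfg.save_num_unsharded_checkpoints_to_keep = cfg.save_num_checkpoints_to_keep
        return CheckpointType.unsharded
    elif cfg.distributed_strategy == DistributedStrategy.fsdp:
        return (
            CheckpointType.sharded
            if cfg.save_num_checkpoints_to_keep != 0
            else CheckpointType.unsharded
        )
    else:
        # Any other strategy is rejected up front, matching the new `else` branch.
        raise NotImplementedError(
            f"Distributed strategy {cfg.distributed_strategy} not supported yet!"
        )


if __name__ == "__main__":
    # Example: a single-accelerator config now resolves to an unsharded checkpoint.
    cfg = SimpleNamespace(
        distributed_strategy=DistributedStrategy.single,
        save_interval=1000,
        save_interval_unsharded=None,
        save_num_checkpoints_to_keep=2,
        save_num_unsharded_checkpoints_to_keep=0,
    )
    print(pick_pretrain_checkpoint_type(cfg))  # CheckpointType.unsharded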