diff --git a/olmo/train.py b/olmo/train.py index 4c1f3b774..a7b5426ae 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -1251,7 +1251,7 @@ def on_trace_ready(p): stop_at = min(stop_at, self.global_step + extra_steps) # Maybe save sharded checkpoint. - if self.cfg.distributed_strategy != DistributedStrategy.ddp: + if self.cfg.distributed_strategy == DistributedStrategy.fsdp: if save_checkpoints and ( cancel_initiated or (