diff --git a/olmo/train.py b/olmo/train.py index 75bef2aea..4c1f3b774 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -431,9 +431,9 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None: random.setstate(rng_state["python"]) np.random.set_state(rng_state["numpy"]) torch.set_rng_state(rng_state["torch"]) - if rng_state["cuda"] is not None: + if rng_state.get("cuda", None) is not None: torch.cuda.set_rng_state(rng_state["cuda"]) - if rng_state["mps"] is not None: + if rng_state.get("mps", None) is not None: torch.mps.set_rng_state(rng_state["mps"]) def _save_checkpoint(