From 311286cacbee13bc4ef19ad7a6b389497237c2e7 Mon Sep 17 00:00:00 2001 From: Peter Schneider-Kamp Date: Sat, 21 Dec 2024 10:31:11 +0100 Subject: [PATCH] backward compatibility for checkpoints --- olmo/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/olmo/train.py b/olmo/train.py index 75bef2aea..4c1f3b774 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -431,9 +431,9 @@ def restore_rng_state(self, rng_state: Dict[str, Any]) -> None: random.setstate(rng_state["python"]) np.random.set_state(rng_state["numpy"]) torch.set_rng_state(rng_state["torch"]) - if rng_state["cuda"] is not None: + if rng_state.get("cuda", None) is not None: torch.cuda.set_rng_state(rng_state["cuda"]) - if rng_state["mps"] is not None: + if rng_state.get("mps", None) is not None: torch.mps.set_rng_state(rng_state["mps"]) def _save_checkpoint(