diff --git a/src/olmo_core/train/checkpoint.py b/src/olmo_core/train/checkpoint.py index 924c0bee..1c260098 100644 --- a/src/olmo_core/train/checkpoint.py +++ b/src/olmo_core/train/checkpoint.py @@ -193,14 +193,15 @@ def load( if get_rank(self.process_group) == 0: try: metadata = get_checkpoint_metadata(model_and_optim_dir) - except FileNotFoundError: + except FileNotFoundError as exc: # Try base directory, which could be the case if user is trying to load model weights # (possibly with optimizer state), and not an actual train checkpoint. if trainer_state is None: metadata = get_checkpoint_metadata(dir) model_and_optim_dir = dir else: - raise + raise FileNotFoundError(f"Missing checkpointing metadata in '{dir}'") from exc + if load_optimizer_state is None: for key in metadata.state_dict_metadata.keys(): if key.startswith("optim."):