diff --git a/bin/predict.py b/bin/predict.py index adcc09d71..c6eb23da1 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -239,7 +239,7 @@ def main( workdir: Path, dataset: str = typer.Option(...), split: str = "val", - epoch: int = typer.Option(...), + checkpoint: str = typer.Option(...), batch_size: int = None, num_samples: int = 3, input_transform_dataset: str = None, @@ -260,7 +260,7 @@ def main( output_dirpath = samples_path( workdir=workdir, - checkpoint=f"epoch-{epoch}", + checkpoint=checkpoint, dataset=dataset, input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}", split=split, @@ -290,7 +290,7 @@ def main( shuffle=False, ) - ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth") + ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth") logger.info(f"Loading model from {ckpt_filename}") state, sampling_fn = load_model(config, ckpt_filename)