Skip to content

Commit

Permalink
use checkpoint rather than epoch for determining checkpoint for sampling
Browse files Browse the repository at this point in the history
just cleaner and more consistent
  • Loading branch information
henryaddison committed Mar 18, 2024
1 parent 4a22a06 commit 354a641
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions bin/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def main(
workdir: Path,
dataset: str = typer.Option(...),
split: str = "val",
epoch: int = typer.Option(...),
checkpoint: str = typer.Option(...),
batch_size: int = None,
num_samples: int = 3,
input_transform_dataset: str = None,
Expand All @@ -260,7 +260,7 @@ def main(

output_dirpath = samples_path(
workdir=workdir,
checkpoint=f"epoch-{epoch}",
checkpoint=checkpoint,
dataset=dataset,
input_xfm=f"{config.data.input_transform_dataset}-{config.data.input_transform_key}",
split=split,
Expand Down Expand Up @@ -290,7 +290,7 @@ def main(
shuffle=False,
)

ckpt_filename = os.path.join(workdir, "checkpoints", f"epoch_{epoch}.pth")
ckpt_filename = os.path.join(workdir, "checkpoints", f"{checkpoint}.pth")
logger.info(f"Loading model from {ckpt_filename}")
state, sampling_fn = load_model(config, ckpt_filename)

Expand Down

0 comments on commit 354a641

Please sign in to comment.