From b7d54de886aecacf02e05f5d6119bb85aeba399a Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 13 Feb 2024 21:58:12 +0000 Subject: [PATCH] correct an update to predict.py --- bin/predict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/predict.py b/bin/predict.py index 514262ef0..df6edbbb9 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -120,7 +120,7 @@ def load_model(config, ckpt_filename): state = _init_state(config) state, loaded = restore_checkpoint(ckpt_filename, state, config.device) assert loaded, "Did not load state from checkpoint" - state["ema"].copy_to(state["score_model"].parameters()) + state["ema"].copy_to(state["model"].parameters()) # Sampling num_output_channels = len(get_variables(config.data.dataset_name)[1])