diff --git a/bin/predict.py b/bin/predict.py index 18b1b4828..169e07bb6 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -98,7 +98,8 @@ def _init_state(config): def load_model(config, ckpt_filename): - if config.deterministic: + deterministic = "deterministic" in config and config.deterministic + if deterministic: sde = None sampling_eps = 0 else: diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index e185a675d..14c398cba 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -140,7 +140,8 @@ def train(config, workdir): initial_epoch = int(state['epoch'])+1 # start from the epoch after the one currently reached # Setup SDEs - if config.deterministic: + deterministic = "deterministic" in config and config.deterministic + if deterministic: sde = None else: if config.training.sde.lower() == 'vpsde': @@ -163,11 +164,11 @@ def train(config, workdir): train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, - deterministic=config.deterministic,) + deterministic=deterministic,) eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, - deterministic=config.deterministic,) + deterministic=deterministic,) num_train_epochs = config.training.n_epochs diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py index 8ee38c919..556f0328f 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py @@ -95,7 +95,7 @@ def get_sampling_fn(config, sde, shape, eps): trailing dimensions matching `shape`. """ - if config.deterministic: + if "deterministic" in config and config.deterministic: sampling_fn = get_deterministic_sampler(shape, device=config.device) else: sampler_name = config.sampling.method