allow for missing deterministic key on config
for backwards compatibility
henryaddison committed Aug 9, 2024
1 parent 468fa7f commit bb37ccd
Showing 3 changed files with 7 additions and 5 deletions.
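The fix is the same one-line guard in all three files: instead of reading config.deterministic directly (which fails on configs created before the flag existed), the key's presence is checked first, so a missing key falls back to False and the previous stochastic (SDE) behaviour. A minimal sketch of the pattern, assuming an ml_collections.ConfigDict-style config object (an assumption; the helper name is also illustrative):

import ml_collections

def is_deterministic(config):
    # Older configs predate the flag entirely; the membership check turns
    # a missing key into False instead of an attribute error.
    return "deterministic" in config and config.deterministic

old_config = ml_collections.ConfigDict()                          # no "deterministic" key
new_config = ml_collections.ConfigDict({"deterministic": True})

assert not is_deterministic(old_config)  # old configs keep the stochastic SDE behaviour
assert is_deterministic(new_config)      # new configs can opt in to the deterministic path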
bin/predict.py: 3 changes (2 additions, 1 deletion)
@@ -98,7 +98,8 @@ def _init_state(config):


 def load_model(config, ckpt_filename):
-    if config.deterministic:
+    deterministic = "deterministic" in config and config.deterministic
+    if deterministic:
         sde = None
         sampling_eps = 0
     else:
src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py: 7 changes (4 additions, 3 deletions)
@@ -140,7 +140,8 @@ def train(config, workdir):
     initial_epoch = int(state['epoch'])+1 # start from the epoch after the one currently reached

     # Setup SDEs
-    if config.deterministic:
+    deterministic = "deterministic" in config and config.deterministic
+    if deterministic:
         sde = None
     else:
         if config.training.sde.lower() == 'vpsde':
@@ -163,11 +164,11 @@
     train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn,
                                        reduce_mean=reduce_mean, continuous=continuous,
                                        likelihood_weighting=likelihood_weighting,
-                                       deterministic=config.deterministic,)
+                                       deterministic=deterministic,)
     eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn,
                                       reduce_mean=reduce_mean, continuous=continuous,
                                       likelihood_weighting=likelihood_weighting,
-                                      deterministic=config.deterministic,)
+                                      deterministic=deterministic,)

     num_train_epochs = config.training.n_epochs

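Note that losses.get_step_fn only ever receives the already-resolved boolean, so the compatibility shim stays in the callers and the loss code never needs to know whether the config key exists. A hypothetical stub showing that boundary (the body, defaults, and inner signature are invented here; only the keyword names match the calls above):

def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False,
                continuous=True, likelihood_weighting=False,
                deterministic=False):
    # `deterministic` arrives here as a plain bool; this function never reads
    # the config, so a missing config key cannot break it.
    def step_fn(state, batch):  # inner signature is illustrative
        raise NotImplementedError("loss computation elided in this sketch")
    return step_fn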
src/ml_downscaling_emulator/score_sde_pytorch/sampling.py: 2 changes (1 addition, 1 deletion)
@@ -95,7 +95,7 @@ def get_sampling_fn(config, sde, shape, eps):
       trailing dimensions matching `shape`.
     """

-    if config.deterministic:
+    if "deterministic" in config and config.deterministic:
         sampling_fn = get_deterministic_sampler(shape, device=config.device)
     else:
         sampler_name = config.sampling.method
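An equivalent, slightly more compact way to write the same fallback, assuming the config object exposes a dict-style get (ml_collections.ConfigDict does), would be:

if config.get("deterministic", False):
    sampling_fn = get_deterministic_sampler(shape, device=config.device)

The explicit membership check used in the commit keeps the three call sites identical and assumes nothing about the config class beyond support for the `in` operator.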
