Merge pull request #34 from henryaddison/stabilize-val-loss
Stabilize val loss
henryaddison authored Aug 3, 2024
2 parents f167214 + 6ea7640 commit 5bf414d
Showing 2 changed files with 22 additions and 14 deletions.
16 changes: 11 additions & 5 deletions src/ml_downscaling_emulator/score_sde_pytorch/losses.py
@@ -74,7 +74,7 @@ def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_we
   """
   reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

-  def loss_fn(model, batch, cond):
+  def loss_fn(model, batch, cond, generator=None):
     """Compute the loss function.

     Args:
@@ -86,8 +86,13 @@ def loss_fn(model, batch, cond):
       loss: A scalar that represents the average loss value across the mini-batch.
     """
     score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
-    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
-    z = torch.randn_like(batch)
+
+    if train:
+      t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
+      z = torch.randn_like(batch)
+    else:
+      t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (sde.T - eps) + eps
+      z = torch.empty_like(batch).normal_(generator=generator)
     mean, std = sde.marginal_prob(batch, t)
     perturbed_data = mean + std[:, None, None, None] * z
     score = score_fn(perturbed_data, cond, t)
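The core change: in eval mode (train=False), the timesteps t and the initial noise z are drawn from a caller-supplied torch.Generator rather than the global RNG. Note the switch from torch.randn_like to torch.empty_like(batch).normal_(generator=...): torch.randn_like does not accept a generator argument. A minimal, self-contained sketch (the helper name sample_t_and_z is illustrative, not from the repository) of why re-seeding the generator makes the draws repeatable:

    import torch

    def sample_t_and_z(batch, generator, T=1.0, eps=1e-5):
        # Mirrors the eval branch above: draw timesteps and noise from an explicit generator.
        t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (T - eps) + eps
        z = torch.empty_like(batch).normal_(generator=generator)
        return t, z

    batch = torch.zeros(4, 1, 8, 8)
    g = torch.Generator(device=batch.device)

    g.manual_seed(42)
    t1, z1 = sample_t_and_z(batch, g)
    g.manual_seed(42)
    t2, z2 = sample_t_and_z(batch, g)
    assert torch.equal(t1, t2) and torch.equal(z1, z2)  # identical draws after re-seeding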
@@ -179,7 +184,7 @@ def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True
   else:
     raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

-  def step_fn(state, batch, cond):
+  def step_fn(state, batch, cond, generator=None):
     """Running one step of training or evaluation.

     This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
@@ -190,6 +195,7 @@ def step_fn(state, batch, cond):
         EMA status, and number of optimization steps.
       batch: A mini-batch of training/evaluation data to model.
       cond: A mini-batch of conditioning inputs.
+      generator: An optional random number generator so the caller can control the timesteps and initial noise samples used by the loss function [ignored in train mode]

     Returns:
       loss: The average loss value of this state.
@@ -208,7 +214,7 @@ def step_fn(state, batch, cond):
         ema = state['ema']
         ema.store(model.parameters())
         ema.copy_to(model.parameters())
-        loss = loss_fn(model, batch, cond)
+        loss = loss_fn(model, batch, cond, generator=generator)
         ema.restore(model.parameters())

     return loss
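step_fn only threads the generator through to loss_fn, so evaluation callers can pin the randomness while training calls are unchanged. A hedged usage sketch (eval_step_fn matches run_lib.py below; train_step_fn and the batch variables are assumed placeholders):

    g = torch.Generator(device=config.device)
    g.manual_seed(42)
    eval_loss = eval_step_fn(state, target_batch, cond_batch, generator=g)  # fixed t and z draws
    train_loss = train_step_fn(state, target_batch, cond_batch)  # generator defaults to None: global RNG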
20 changes: 11 additions & 9 deletions src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -52,20 +52,24 @@

 EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")

-def val_loss(config, eval_ds, eval_step_fn, state):
+def val_loss(config, eval_dl, eval_step_fn, state):
   val_set_loss = 0.0
-  for eval_cond_batch, eval_target_batch, eval_time_batch in eval_ds:
-    # eval_cond_batch, eval_target_batch = next(iter(eval_ds))
+  # use a consistent generator for computing validation set loss
+  # so value is not down to vagaries of random choice of initial noise samples or schedules
+  g = torch.Generator(device=config.device)
+  g.manual_seed(42)
+  for eval_cond_batch, eval_target_batch, eval_time_batch in eval_dl:
+    # eval_cond_batch, eval_target_batch = next(iter(eval_dl))
     eval_target_batch = eval_target_batch.to(config.device)
     eval_cond_batch = eval_cond_batch.to(config.device)
     # append any location-specific parameters
     eval_cond_batch = state['location_params'](eval_cond_batch)
     # eval_batch = eval_batch.permute(0, 3, 1, 2)
-    eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch)
+    eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch, generator=g)

     # Progress
     val_set_loss += eval_loss.item()
-  val_set_loss = val_set_loss/len(eval_ds)
+  val_set_loss = val_set_loss/len(eval_dl)

   return val_set_loss
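With the generator re-seeded to 42 on every call, and the validation loader no longer shuffled (see the train() change below), val_loss is a deterministic function of the model weights and the validation data: any epoch-to-epoch movement now reflects the model, not the luck of the noise and timestep draws. A sketch of the property this buys (hypothetical test harness, not part of the commit):

    loss_a = val_loss(config, eval_dl, eval_step_fn, state)
    loss_b = val_loss(config, eval_dl, eval_step_fn, state)
    assert loss_a == loss_b  # same seed and batch order => identical t, z, and loss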

@@ -112,7 +116,7 @@ def train(config, workdir):
   # Build dataloaders
   dataset_meta = DatasetMetadata(config.data.dataset_name)
   train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-  eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+  eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False)

   # Initialize model.
   score_model = mutils.create_model(config)
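The shuffle=False matters because the k-th draw from the seeded generator pairs with the k-th validation batch; that pairing only stays fixed across epochs if the loader yields batches in a stable order. A self-contained sketch with a stand-in dataset (plain torch.utils.data, not the project's get_dataloader):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    val_dataset = TensorDataset(torch.arange(8).float())
    eval_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)
    # A non-shuffled loader is deterministic: two passes yield the batches in the same order.
    assert [b[0].tolist() for b in eval_loader] == [b[0].tolist() for b in eval_loader]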
@@ -208,9 +212,7 @@ def train(config, workdir):
         val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
         epoch_metrics = {"epoch/train/loss": train_set_loss, "epoch/val/loss": val_set_loss}

-        logging.info("epoch: %d, val_loss: %.5e" % (state['epoch'], val_set_loss))
-        writer.add_scalar("epoch/val/loss", val_set_loss, global_step=state['epoch'])
-        log_epoch(state['epoch'], epoch_metrics, wandb_run,writer)
+        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)

         if (state['epoch'] != 0 and state['epoch'] % config.training.snapshot_freq == 0) or state['epoch'] == num_train_epochs:
           # Save the checkpoint.
