diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
index 19b76bd95..a0adc231b 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
@@ -74,7 +74,7 @@ def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_we
   """
   reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
 
-  def loss_fn(model, batch, cond):
+  def loss_fn(model, batch, cond, generator=None):
     """Compute the loss function.
 
     Args:
@@ -86,8 +86,13 @@ def loss_fn(model, batch, cond):
       loss: A scalar that represents the average loss value across the mini-batch.
     """
     score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
-    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
-    z = torch.randn_like(batch)
+
+    if train:
+      t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
+      z = torch.randn_like(batch)
+    else:
+      t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (sde.T - eps) + eps
+      z = torch.empty_like(batch).normal_(generator=generator)
     mean, std = sde.marginal_prob(batch, t)
     perturbed_data = mean + std[:, None, None, None] * z
     score = score_fn(perturbed_data, cond, t)
@@ -179,7 +184,7 @@ def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True
   else:
     raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")
 
-  def step_fn(state, batch, cond):
+  def step_fn(state, batch, cond, generator=None):
     """Running one step of training or evaluation.
 
     This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
@@ -190,6 +195,7 @@ def step_fn(state, batch, cond):
        EMA status, and number of optimization steps.
       batch: A mini-batch of training/evaluation data to model.
       cond: A mini-batch of conditioning inputs.
+      generator: An optional random number generator so one can control the timesteps and initial noise samples drawn by the loss function (ignored in train mode).
 
     Returns:
       loss: The average loss value of this state.
@@ -208,7 +214,7 @@ def step_fn(state, batch, cond):
         ema = state['ema']
         ema.store(model.parameters())
         ema.copy_to(model.parameters())
-        loss = loss_fn(model, batch, cond)
+        loss = loss_fn(model, batch, cond, generator=generator)
         ema.restore(model.parameters())
 
     return loss
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 998df6920..986a03865 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -52,20 +52,24 @@
 EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
 
 
-def val_loss(config, eval_ds, eval_step_fn, state):
+def val_loss(config, eval_dl, eval_step_fn, state):
   val_set_loss = 0.0
-  for eval_cond_batch, eval_target_batch, eval_time_batch in eval_ds:
-    # eval_cond_batch, eval_target_batch = next(iter(eval_ds))
+  # use a generator with a fixed seed for computing the validation set loss
+  # so its value does not depend on the random draw of initial noise samples or timesteps
+  g = torch.Generator(device=config.device)
+  g.manual_seed(42)
+  for eval_cond_batch, eval_target_batch, eval_time_batch in eval_dl:
+    # eval_cond_batch, eval_target_batch = next(iter(eval_dl))
     eval_target_batch = eval_target_batch.to(config.device)
     eval_cond_batch = eval_cond_batch.to(config.device)
     # append any location-specific parameters
    eval_cond_batch = state['location_params'](eval_cond_batch)
 
     # eval_batch = eval_batch.permute(0, 3, 1, 2)
-    eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch)
+    eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch, generator=g)
 
     # Progress
     val_set_loss += eval_loss.item()
-  val_set_loss = val_set_loss/len(eval_ds)
+  val_set_loss = val_set_loss/len(eval_dl)
 
   return val_set_loss
@@ -112,7 +116,7 @@ def train(config, workdir):
   # Build dataloaders
   dataset_meta = DatasetMetadata(config.data.dataset_name)
   train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-  eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+  eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False)
 
   # Initialize model.
   score_model = mutils.create_model(config)
@@ -208,9 +212,7 @@ def train(config, workdir):
 
     val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
     epoch_metrics = {"epoch/train/loss": train_set_loss, "epoch/val/loss": val_set_loss}
-    logging.info("epoch: %d, val_loss: %.5e" % (state['epoch'], val_set_loss))
-    writer.add_scalar("epoch/val/loss", val_set_loss, global_step=state['epoch'])
-    log_epoch(state['epoch'], epoch_metrics, wandb_run,writer)
+    log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)
 
     if (state['epoch'] != 0 and state['epoch'] % config.training.snapshot_freq == 0) or state['epoch'] == num_train_epochs:
       # Save the checkpoint.
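
Note (not part of the patch): the point of threading a seeded torch.Generator through step_fn and loss_fn is that every validation pass draws the same diffusion timesteps t and initial noise z, so changes in epoch/val/loss between epochs reflect the model rather than sampling noise. Below is a minimal standalone sketch of that behaviour; the helper name sample_t_and_z, the batch shape, and the T/eps values are illustrative assumptions, not taken from the patch.

# Minimal CPU-only sketch of the reproducibility a seeded generator provides:
# re-seeding before each pass reproduces the same timesteps t and noise z,
# mirroring the eval branch of loss_fn in the patch.
import torch

def sample_t_and_z(batch, generator, T=1.0, eps=1e-5):
    # Draw timesteps and noise from the supplied generator instead of the global RNG.
    t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (T - eps) + eps
    z = torch.empty_like(batch).normal_(generator=generator)
    return t, z

batch = torch.zeros(4, 1, 8, 8)  # illustrative batch shape

g = torch.Generator(device="cpu")
g.manual_seed(42)
t1, z1 = sample_t_and_z(batch, g)

g.manual_seed(42)  # re-seed, equivalent to val_loss creating a fresh generator each pass
t2, z2 = sample_t_and_z(batch, g)

assert torch.equal(t1, t2) and torch.equal(z1, z2)

In train mode the patch keeps the global-RNG path (torch.randn_like), so training stochasticity is unchanged.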