From d60921e7e581ecd931b2578cabe3a412a18c4068 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Thu, 20 Jun 2024 02:59:24 +0100
Subject: [PATCH 1/4] attempt to stabilize validation loss by using the same
 random numbers for computing loss when not in train mode

---
 .../score_sde_pytorch/losses.py  | 16 +++++++++++-----
 .../score_sde_pytorch/run_lib.py | 16 ++++++++++------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
index 19b76bd95..a0adc231b 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py
@@ -74,7 +74,7 @@ def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_we
   """
   reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)
 
-  def loss_fn(model, batch, cond):
+  def loss_fn(model, batch, cond, generator=None):
     """Compute the loss function.
 
     Args:
@@ -86,8 +86,13 @@ def loss_fn(model, batch, cond):
       loss: A scalar that represents the average loss value across the mini-batch.
     """
     score_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
-    t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
-    z = torch.randn_like(batch)
+
+    if train:
+      t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
+      z = torch.randn_like(batch)
+    else:
+      t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (sde.T - eps) + eps
+      z = torch.empty_like(batch).normal_(generator=generator)
     mean, std = sde.marginal_prob(batch, t)
     perturbed_data = mean + std[:, None, None, None] * z
     score = score_fn(perturbed_data, cond, t)
@@ -179,7 +184,7 @@ def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True
   else:
     raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")
 
-  def step_fn(state, batch, cond):
+  def step_fn(state, batch, cond, generator=None):
     """Running one step of training or evaluation.
 
     This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
@@ -190,6 +195,7 @@ def step_fn(state, batch, cond):
        EMA status, and number of optimization steps.
       batch: A mini-batch of training/evaluation data to model.
       cond: A mini-batch of conditioning inputs.
+      generator: An optional random number generator so the caller can control the timesteps and initial noise samples used by the loss function [ignored in train mode]
 
     Returns:
       loss: The average loss value of this state.
@@ -208,7 +214,7 @@ def step_fn(state, batch, cond):
         ema = state['ema']
         ema.store(model.parameters())
         ema.copy_to(model.parameters())
-        loss = loss_fn(model, batch, cond)
+        loss = loss_fn(model, batch, cond, generator=generator)
         ema.restore(model.parameters())
 
     return loss
diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 998df6920..03a1fd28d 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -52,20 +52,24 @@
 EXPERIMENT_NAME = os.getenv("WANDB_EXPERIMENT_NAME")
 
 
-def val_loss(config, eval_ds, eval_step_fn, state):
+def val_loss(config, eval_dl, eval_step_fn, state):
     val_set_loss = 0.0
-    for eval_cond_batch, eval_target_batch, eval_time_batch in eval_ds:
-        # eval_cond_batch, eval_target_batch = next(iter(eval_ds))
+    # use a consistent generator for computing the validation set loss
+    # so its value does not depend on the random choice of initial noise samples or timesteps
+    g = torch.Generator(device=config.device)
+    g.manual_seed(42)
+    for eval_cond_batch, eval_target_batch, eval_time_batch in eval_dl:
+        # eval_cond_batch, eval_target_batch = next(iter(eval_dl))
         eval_target_batch = eval_target_batch.to(config.device)
         eval_cond_batch = eval_cond_batch.to(config.device)
         # append any location-specific parameters
        eval_cond_batch = state['location_params'](eval_cond_batch)
         # eval_batch = eval_batch.permute(0, 3, 1, 2)
-        eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch)
+        eval_loss = eval_step_fn(state, eval_target_batch, eval_cond_batch, generator=g)
 
         # Progress
         val_set_loss += eval_loss.item()
-    val_set_loss = val_set_loss/len(eval_ds)
+    val_set_loss = val_set_loss/len(eval_dl)
 
     return val_set_loss
 
@@ -112,7 +116,7 @@ def train(config, workdir):
     # Build dataloaders
     dataset_meta = DatasetMetadata(config.data.dataset_name)
     train_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="train", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
-    eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False)
+    eval_dl, _, _ = get_dataloader(config.data.dataset_name, config.data.dataset_name, config.data.dataset_name, config.data.input_transform_key, config.data.target_transform_key, transform_dir, batch_size=config.training.batch_size, split="val", ensemble_members=dataset_meta.ensemble_members(), include_time_inputs=config.data.time_inputs, evaluation=False, shuffle=False)
 
     # Initialize model.
     score_model = mutils.create_model(config)

From ecff87c9246cf214d71cbd91013f51c9718eaea5 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Mon, 24 Jun 2024 14:18:02 +0100
Subject: [PATCH 2/4] compute and log val loss before 1st epoch

partly for debugging but also to see improvements
---
 src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 03a1fd28d..327e2f354 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -171,6 +171,12 @@ def train(config, workdir):
     if config.training.random_crop_size > 0:
         random_crop = torchvision.transforms.RandomCrop(config.training.random_crop_size)
 
+    # log val loss before any training
+    if int(state['epoch']) == 0:
+        val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
+        epoch_metrics = {"epoch/val/loss": val_set_loss}
+        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)
+
     for epoch in range(initial_epoch, num_train_epochs + 1):
         state['epoch'] = epoch
         train_set_loss = 0.0

From 48ac06b99e43c01cad21f3cec8b2c5f9524fd30d Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Mon, 24 Jun 2024 14:18:30 +0100
Subject: [PATCH 3/4] don't duplicate logging for validation loss each epoch

---
 src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 327e2f354..0aefb72ae 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -218,9 +218,7 @@ def train(config, workdir):
         val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
         epoch_metrics = {"epoch/train/loss": train_set_loss, "epoch/val/loss": val_set_loss}
 
-        logging.info("epoch: %d, val_loss: %.5e" % (state['epoch'], val_set_loss))
-        writer.add_scalar("epoch/val/loss", val_set_loss, global_step=state['epoch'])
-        log_epoch(state['epoch'], epoch_metrics, wandb_run,writer)
+        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)
 
         if (state['epoch'] != 0 and state['epoch'] % config.training.snapshot_freq == 0) or state['epoch'] == num_train_epochs:
             # Save the checkpoint.

From 6ea7640fbde1181d5328e28e5a3c35aaa3939a61 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Mon, 8 Jul 2024 21:39:22 +0100
Subject: [PATCH 4/4] don't bother recording val loss before any training

it's large
---
 src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
index 0aefb72ae..986a03865 100644
--- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
+++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py
@@ -171,12 +171,6 @@ def train(config, workdir):
     if config.training.random_crop_size > 0:
         random_crop = torchvision.transforms.RandomCrop(config.training.random_crop_size)
 
-    # log val loss before any training
-    if int(state['epoch']) == 0:
-        val_set_loss = val_loss(config, eval_dl, eval_step_fn, state)
-        epoch_metrics = {"epoch/val/loss": val_set_loss}
-        log_epoch(state['epoch'], epoch_metrics, wandb_run, writer)
-
     for epoch in range(initial_epoch, num_train_epochs + 1):
         state['epoch'] = epoch
        train_set_loss = 0.0
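
The mechanism PATCH 1 relies on is worth spelling out. torch.randn_like does not accept a generator, so the evaluation branch draws noise with torch.empty_like(batch).normal_(generator=generator) and timesteps with torch.rand(..., generator=generator), while val_loss re-seeds a dedicated torch.Generator before every pass over the validation set. The same timesteps and noise are therefore replayed at every evaluation, and the reported validation loss changes only when the model weights change. Below is a minimal standalone sketch of that behaviour, not the repository's actual loss function: the sde_T and eps values and the toy batch are placeholders for illustration only.

import torch

def sample_t_and_z(batch, sde_T=1.0, eps=1e-5, generator=None):
    # Mirrors the eval branch added in PATCH 1: timesteps and noise are
    # drawn from an explicit generator so they can be replayed exactly.
    t = torch.rand(batch.shape[0], device=batch.device, generator=generator) * (sde_T - eps) + eps
    z = torch.empty_like(batch).normal_(generator=generator)
    return t, z

batch = torch.zeros(4, 1, 8, 8)  # stand-in for a validation mini-batch

# Re-seeding the generator before each "validation pass" reproduces the same
# timesteps and noise, so the loss only varies with the model weights.
g = torch.Generator(device=batch.device.type)
g.manual_seed(42)
t1, z1 = sample_t_and_z(batch, generator=g)

g.manual_seed(42)
t2, z2 = sample_t_and_z(batch, generator=g)

assert torch.equal(t1, t2) and torch.equal(z1, z2)

The empty_like(...).normal_(generator=...) idiom is the usual workaround for randn_like not taking a generator argument; the fixed seed (42 in the patch) is arbitrary, it only needs to be the same on every validation pass.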