From 66f892638e51b1a5935ea112f2c9a7d6bbf9b71b Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 13 Feb 2024 14:11:49 +0000 Subject: [PATCH 1/8] add a config for using models in a deterministic setup rather than score-sde/diffusion setting --- .../default_ukcp_local_pr_1em_configs.py | 1 + .../configs/deterministic/__init__.py | 0 .../configs/deterministic/default_configs.py | 86 +++++++++++++++++++ .../ukcp_local_pr_12em_cncsnpp.py | 62 +++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/__init__.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py index 5939cc3aa..7c11b6ef6 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/default_ukcp_local_pr_1em_configs.py @@ -74,5 +74,6 @@ def get_default_configs(): config.seed = 42 config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + config.deterministic = False return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/__init__.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py new file mode 100644 index 000000000..fb9c72037 --- /dev/null +++ 
b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py @@ -0,0 +1,86 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Defaults for training in a deterministic fashion.""" + +import ml_collections +import torch + +def get_default_configs(): + config = ml_collections.ConfigDict() + + config.deterministic = True + + # training + config.training = training = ml_collections.ConfigDict() + training.batch_size = 16#128 + training.snapshot_freq = 25 + training.log_freq = 50 + training.eval_freq = 1000 + ## store additional checkpoints for preemption in cloud computing environments + training.snapshot_freq_for_preemption = 1000 + ## produce samples at each snapshot. 
+ training.snapshot_sampling = False + training.likelihood_weighting = False + training.continuous = True + training.reduce_mean = False + training.random_crop_size = 0 + training.continuous = True + training.reduce_mean = True + training.n_epochs = 20 + training.snapshot_freq = 5 + training.eval_freq = 5000 + training.sde = "" + + # sampling + config.sampling = sampling = ml_collections.ConfigDict() + + # evaluation + config.eval = evaluate = ml_collections.ConfigDict() + evaluate.batch_size = 128 + + # data + config.data = data = ml_collections.ConfigDict() + data.dataset = 'UKCP_Local' + data.image_size = 64 + data.random_flip = False + data.uniform_dequantization = False + data.time_inputs = False + data.centered = True + data.input_transform_key = "stan" + data.target_transform_key = "sqrturrecen" + + # model + config.model = model = ml_collections.ConfigDict() + model.loc_spec_channels = 0 + model.num_scales = 0 + model.ema_rate = 0.9999 + + # optimization + config.optim = optim = ml_collections.ConfigDict() + optim.weight_decay = 0 + optim.optimizer = 'Adam' + optim.lr = 2e-4 + optim.beta1 = 0.9 + optim.eps = 1e-8 + optim.warmup = 5000 + optim.grad_clip = 1. + + config.seed = 42 + config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') + + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py new file mode 100644 index 000000000..bb0d55bc0 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_cncsnpp.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Training NCSN++ on precip data in a deterministic fashion.""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 100 + + # data + data = config.data + data.dataset_name = 'bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr' + + # model + model = config.model + model.name = 'cncsnpp' + model.loc_spec_channels = 0 + model.dropout = 0.1 + model.embedding_type = 'fourier' + model.scale_by_sigma = False + model.ema_rate = 0.9999 + model.normalization = 'GroupNorm' + model.nonlinearity = 'swish' + model.nf = 128 + model.ch_mult = (1, 2, 2, 2) + model.num_res_blocks = 4 + model.attn_resolutions = (16,) + model.resamp_with_conv = True + model.conditional = True + model.fir = True + model.fir_kernel = [1, 3, 3, 1] + model.skip_rescale = True + model.resblock_type = 'biggan' + model.progressive = 'none' + model.progressive_input = 'residual' + model.progressive_combine = 'sum' + model.attention_type = 'ddpm' + model.embedding_type = 'positional' + model.init_scale = 0. 
+ model.fourier_scale = 16 + model.conv_size = 3 + + return config From 7112b54fe8952299ac9f48e9fa8609300e32ccd2 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Sat, 3 Aug 2024 15:01:46 +0100 Subject: [PATCH 2/8] update training and sampling to handle deterministic approach as well as diffusion type models --- bin/predict.py | 48 ++++++------ .../score_sde_pytorch/losses.py | 47 +++++++++--- .../score_sde_pytorch/run_lib.py | 29 +++++--- .../score_sde_pytorch/sampling.py | 74 +++++++++++++------ 4 files changed, 133 insertions(+), 65 deletions(-) diff --git a/bin/predict.py b/bin/predict.py index 42924da54..6db6155a1 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -97,29 +97,33 @@ def _init_state(config): def load_model(config, ckpt_filename): - if config.training.sde == "vesde": - sde = VESDE( - sigma_min=config.model.sigma_min, - sigma_max=config.model.sigma_max, - N=config.model.num_scales, - ) - sampling_eps = 1e-5 - elif config.training.sde == "vpsde": - sde = VPSDE( - beta_min=config.model.beta_min, - beta_max=config.model.beta_max, - N=config.model.num_scales, - ) - sampling_eps = 1e-3 - elif config.training.sde == "subvpsde": - sde = subVPSDE( - beta_min=config.model.beta_min, - beta_max=config.model.beta_max, - N=config.model.num_scales, - ) - sampling_eps = 1e-3 + if config.deterministic: + sde = None + sampling_eps = 0 else: - raise RuntimeError(f"Unknown SDE {config.training.sde}") + if config.training.sde == "vesde": + sde = VESDE( + sigma_min=config.model.sigma_min, + sigma_max=config.model.sigma_max, + N=config.model.num_scales, + ) + sampling_eps = 1e-5 + elif config.training.sde == "vpsde": + sde = VPSDE( + beta_min=config.model.beta_min, + beta_max=config.model.beta_max, + N=config.model.num_scales, + ) + sampling_eps = 1e-3 + elif config.training.sde == "subvpsde": + sde = subVPSDE( + beta_min=config.model.beta_min, + beta_max=config.model.beta_max, + N=config.model.num_scales, + ) + sampling_eps = 1e-3 + else: + raise 
RuntimeError(f"Unknown SDE {config.training.sde}") # sigmas = mutils.get_sigmas(config) # noqa: F841 state = _init_state(config) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py index a0adc231b..a27bf7dae 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/losses.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/losses.py @@ -22,6 +22,7 @@ import torch import torch.optim as optim +import torch.nn.functional as F import numpy as np from .models import utils as mutils from .sde_lib import VESDE, VPSDE @@ -56,6 +57,28 @@ def optimize_fn(optimizer, params, step, lr=config.optim.lr, return optimize_fn +def get_deterministic_loss_fn(train, reduce_mean=True): + def loss_fn(model, batch, cond, generator=None): + """Compute the loss function for a deterministic run. + + Args: + model: A score model. + batch: A mini-batch of training/evaluation data to model. + cond: A mini-batch of conditioning inputs. + generator: An optional random number generator so can control the timesteps and initial noise samples used by loss function [ignored in train mode] + + Returns: + loss: A scalar that represents the average loss value across the mini-batch. + """ + # for deterministic model, do not use the time or target inputs - set to 0 always + x = torch.zeros_like(batch) + t = torch.zeros(batch.shape[0], device=batch.device) + pred = model(x, cond, t) + loss = F.mse_loss(pred, batch, reduction="mean") + return loss + + return loss_fn + def get_sde_loss_fn(sde, train, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5): """Create a loss function for training with arbirary SDEs. 
@@ -158,7 +181,7 @@ def loss_fn(model, batch): return loss_fn -def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False): +def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False, deterministic=False): """Create a one-step training/evaluation function. Args: @@ -168,21 +191,25 @@ def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True continuous: `True` indicates that the model is defined to take continuous time steps. likelihood_weighting: If `True`, weight the mixture of score matching losses according to https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper. + deterministic: If true, use deterministic mode loss, else use diffusion losses. Returns: A one-step function for training or evaluation. """ - if continuous: - loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, - continuous=True, likelihood_weighting=likelihood_weighting) + if deterministic: + loss_fn = get_deterministic_loss_fn(train, reduce_mean=reduce_mean) else: - assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." - if isinstance(sde, VESDE): - loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean) - elif isinstance(sde, VPSDE): - loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean) + if continuous: + loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean, + continuous=True, likelihood_weighting=likelihood_weighting) else: - raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") + assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training." 
+ if isinstance(sde, VESDE): + loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean) + elif isinstance(sde, VPSDE): + loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean) + else: + raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.") def step_fn(state, batch, cond, generator=None): """Running one step of training or evaluation. diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 986a03865..ccfa986fa 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -139,17 +139,20 @@ def train(config, workdir): initial_epoch = int(state['epoch'])+1 # start from the epoch after the one currently reached # Setup SDEs - if config.training.sde.lower() == 'vpsde': - sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) - sampling_eps = 1e-3 - elif config.training.sde.lower() == 'subvpsde': - sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) - sampling_eps = 1e-3 - elif config.training.sde.lower() == 'vesde': - sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max, N=config.model.num_scales) - sampling_eps = 1e-5 + if config.deterministic: + sde = None else: - raise NotImplementedError(f"SDE {config.training.sde} unknown.") + if config.training.sde.lower() == 'vpsde': + sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + sampling_eps = 1e-3 + elif config.training.sde.lower() == 'subvpsde': + sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max, N=config.model.num_scales) + sampling_eps = 1e-3 + elif config.training.sde.lower() == 'vesde': + sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, 
sigma_max=config.model.sigma_max, N=config.model.num_scales) + sampling_eps = 1e-5 + else: + raise NotImplementedError(f"SDE {config.training.sde} unknown.") # Build one-step training and evaluation functions optimize_fn = losses.optimization_manager(config) @@ -158,10 +161,12 @@ def train(config, workdir): likelihood_weighting = config.training.likelihood_weighting train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, - likelihood_weighting=likelihood_weighting) + likelihood_weighting=likelihood_weighting, + deterministic=config.deterministic,) eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, - likelihood_weighting=likelihood_weighting) + likelihood_weighting=likelihood_weighting, + deterministic=config.deterministic,) num_train_epochs = config.training.n_epochs diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py index 28ee1c7a8..8ee38c919 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py @@ -95,31 +95,34 @@ def get_sampling_fn(config, sde, shape, eps): trailing dimensions matching `shape`. """ - sampler_name = config.sampling.method - # Probability flow ODE sampling with black-box ODE solvers - if sampler_name.lower() == 'ode': - sampling_fn = get_ode_sampler(sde=sde, + if config.deterministic: + sampling_fn = get_deterministic_sampler(shape, device=config.device) + else: + sampler_name = config.sampling.method + # Probability flow ODE sampling with black-box ODE solvers + if sampler_name.lower() == 'ode': + sampling_fn = get_ode_sampler(sde=sde, + shape=shape, + denoise=config.sampling.noise_removal, + eps=eps, + device=config.device) + # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases. 
+ elif sampler_name.lower() == 'pc': + predictor = get_predictor(config.sampling.predictor.lower()) + corrector = get_corrector(config.sampling.corrector.lower()) + sampling_fn = get_pc_sampler(sde=sde, shape=shape, + predictor=predictor, + corrector=corrector, + snr=config.sampling.snr, + n_steps=config.sampling.n_steps_each, + probability_flow=config.sampling.probability_flow, + continuous=config.training.continuous, denoise=config.sampling.noise_removal, eps=eps, device=config.device) - # Predictor-Corrector sampling. Predictor-only and Corrector-only samplers are special cases. - elif sampler_name.lower() == 'pc': - predictor = get_predictor(config.sampling.predictor.lower()) - corrector = get_corrector(config.sampling.corrector.lower()) - sampling_fn = get_pc_sampler(sde=sde, - shape=shape, - predictor=predictor, - corrector=corrector, - snr=config.sampling.snr, - n_steps=config.sampling.n_steps_each, - probability_flow=config.sampling.probability_flow, - continuous=config.training.continuous, - denoise=config.sampling.noise_removal, - eps=eps, - device=config.device) - else: - raise ValueError(f"Sampler name {sampler_name} unknown.") + else: + raise ValueError(f"Sampler name {sampler_name} unknown.") return sampling_fn @@ -486,3 +489,32 @@ def ode_func(t, x): return x, nfe return ode_sampler + + +# + +def get_deterministic_sampler(shape, device='cuda'): + """The 'sampler' for a deterministic model. + + Args: + model: A deterministic model. + cond: A PyTorch tensor representing the conditioning inputs for this sample + Returns: + samples, number of function evaluations. 
+ """ + def deterministic_sampler(model, cond): + with torch.no_grad(): + # Initial sample + # set batch size of output based on the conditioning input (since batches may vary in size) + output_shape = (cond.shape[0], *shape[1:]) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + x = torch.zeros(output_shape, device=device) + t = torch.zeros(output_shape[0], device=device) + # cond = cond.to(device) + + samples = model(x, cond, t) + + return samples, 1 + + return deterministic_sampler From fc1ccdcf4ac96e88388f309c10a545809dc4579e Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Sat, 3 Aug 2024 15:27:31 +0100 Subject: [PATCH 3/8] add a smoke test for debugging det models in the main module --- .../deterministic/ukcp_local_pr_debug.py | 42 +++++++++++++++++++ tests/smoke-test-det | 20 +++++++++ 2 files changed, 62 insertions(+) create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py create mode 100755 tests/smoke-test-det diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py new file mode 100644 index 000000000..2a8153fa4 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_debug.py @@ -0,0 +1,42 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Debug config for training in a deterministic fashion.""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 2 + training.snapshot_freq = 5 + training.eval_freq = 100 + training.log_freq = 50 + training.batch_size = 2 + + # data + data = config.data + data.dataset_name = 'debug-sample' + data.time_inputs=True + + # model + model = config.model + model.name = 'cunet' + + return config diff --git a/tests/smoke-test-det b/tests/smoke-test-det new file mode 100755 index 000000000..b2262f995 --- /dev/null +++ b/tests/smoke-test-det @@ -0,0 +1,20 @@ +#! /usr/bin/env bash + +set -euo pipefail + +config_name="ukcp_local_pr_debug" + +workdir="output/test/deterministic/${config_name}/test-run" +config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py" + +loc_spec_channels=0 + +rm -rf ${workdir} +WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.model.loc_spec_channels=${loc_spec_channels} + +num_samples=2 +eval_batch_size=32 +checkpoint="epoch_2" + +rm -rf "${workdir}/samples/${checkpoint}" +python bin/predict.py ${workdir} --dataset debug-sample --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01 From 34369c5843ce7dd3398c4a2b219ebc5f99e6bf67 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Sat, 3 Aug 2024 16:23:35 +0100 Subject: [PATCH 4/8] add plain unet implementation into the score_sde_pytorch module in theory this should replace the mirroring deterministic package but obviously we should test this once blue pebble is back up properly also include configs for a now more tuned config for plain det 
unet and one that resembles the old det unet config more closely --- bin/predict.py | 1 + .../configs/ukcp_local_12em_pr_unet.py | 4 +- .../configs/deterministic/default_configs.py | 20 +++--- .../ukcp_local_pr_12em_plain_unet.py | 56 +++++++++++++++++ .../ukcp_local_pr_12em_tuned_plain_unet.py | 57 +++++++++++++++++ .../ukcp_local_pr_1em_cncsnpp.py | 62 +++++++++++++++++++ .../ukcp_local_pr_plain_unet_debug.py | 46 ++++++++++++++ .../score_sde_pytorch/models/det_cunet.py | 45 ++++++++++++++ .../score_sde_pytorch/run_lib.py | 6 +- tests/smoke-test-det-plain-unet | 21 +++++++ 10 files changed, 304 insertions(+), 14 deletions(-) create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py create mode 100644 src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py create mode 100755 tests/smoke-test-det-plain-unet diff --git a/bin/predict.py b/bin/predict.py index 6db6155a1..18b1b4828 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -36,6 +36,7 @@ from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp # noqa: F401 from ml_downscaling_emulator.score_sde_pytorch.models import cunet # noqa: F401 +from ml_downscaling_emulator.score_sde_pytorch.models import det_cunet # noqa: F401 from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401 layerspp, # noqa: F401 diff --git a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py index 821accfd9..3cf49481a 100644 --- 
a/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py +++ b/src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py @@ -16,9 +16,7 @@ def get_config(): evaluate.batch_size = 64 config.data = data = ml_collections.ConfigDict() - data.dataset_name = ( - "bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season" - ) + data.dataset_name = "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr" data.input_transform_key = "stan" data.target_transform_key = "sqrturrecen" data.input_transform_dataset = None diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py index fb9c72037..12db7ee2e 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py @@ -27,10 +27,11 @@ def get_default_configs(): # training config.training = training = ml_collections.ConfigDict() - training.batch_size = 16#128 + training.n_epochs = 20 + training.batch_size = 16 training.snapshot_freq = 25 - training.log_freq = 50 - training.eval_freq = 1000 + training.log_freq = 500 + training.eval_freq = 5000 ## store additional checkpoints for preemption in cloud computing environments training.snapshot_freq_for_preemption = 1000 ## produce samples at each snapshot. @@ -41,9 +42,6 @@ def get_default_configs(): training.random_crop_size = 0 training.continuous = True training.reduce_mean = True - training.n_epochs = 20 - training.snapshot_freq = 5 - training.eval_freq = 5000 training.sde = "" # sampling @@ -66,17 +64,23 @@ def get_default_configs(): # model config.model = model = ml_collections.ConfigDict() + model.sigma_min = 0.01 + model.sigma_max = 50 + model.beta_min = 0.1 + model.beta_max = 20. 
model.loc_spec_channels = 0 - model.num_scales = 0 + model.num_scales = 1 model.ema_rate = 0.9999 + model.dropout = 0.1 + model.embedding_type = 'fourier' # optimization config.optim = optim = ml_collections.ConfigDict() - optim.weight_decay = 0 optim.optimizer = 'Adam' optim.lr = 2e-4 optim.beta1 = 0.9 optim.eps = 1e-8 + optim.weight_decay = 0 optim.warmup = 5000 optim.grad_clip = 1. diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py new file mode 100644 index 000000000..7b141272c --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py @@ -0,0 +1,56 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Config for training a purely deterministic plain U-Net model. + +This is opposed to using a model ready for score-based denoising +but training it in a deterministic fashion. 
+""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 100 + training.snapshot_freq = 20 + training.batch_size = 64 + + # data + data = config.data + data.dataset_name = 'bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr' + data.input_transform_key = "stan" + data.target_transform_key = "sqrturrecen" + data.input_transform_dataset = None + data.time_inputs = False + + # model + model = config.model + model.name = 'det_cunet' + + # optimizer + optim = config.optim + optim.optimizer = "Adam" + optim.lr = 2e-4 + optim.beta1 = 0.9 + optim.eps = 1e-8 + optim.weight_decay = 0 + optim.warmup = -1 # 5000 + optim.grad_clip = -1. # 1. + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py new file mode 100644 index 000000000..d032d3175 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py @@ -0,0 +1,57 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Lint as: python3 +"""Tuned config for training a purely deterministic plain U-Net model. + +This is opposed to using a model ready for score-based denoising +but training it in a deterministic fashion. +""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 100 + training.snapshot_freq = 20 + training.batch_size = 256 + + # data + data = config.data + data.dataset_name = 'bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr' + data.input_transform_key = "stan" + data.target_transform_key = "sqrturrecen" + data.input_transform_dataset = None + data.time_inputs = False + + # model + model = config.model + model.name = 'det_cunet' + model.ema_disabled = False + + # optimizer + optim = config.optim + optim.optimizer = "Adam" + optim.lr = 2e-4 + optim.beta1 = 0.9 + optim.eps = 1e-8 + optim.weight_decay = 0 + optim.warmup = 5000 + optim.grad_clip = 1. + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py new file mode 100644 index 000000000..8a4422201 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_1em_cncsnpp.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Training NCSN++ on precip data in a deterministic fashion.""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 100 + + # data + data = config.data + data.dataset_name = 'bham64_ccpm-4x_1em_psl-sphum4th-temp4th-vort4th_pr' + + # model + model = config.model + model.name = 'cncsnpp' + model.loc_spec_channels = 0 + model.dropout = 0.1 + model.embedding_type = 'fourier' + model.scale_by_sigma = False + model.ema_rate = 0.9999 + model.normalization = 'GroupNorm' + model.nonlinearity = 'swish' + model.nf = 128 + model.ch_mult = (1, 2, 2, 2) + model.num_res_blocks = 4 + model.attn_resolutions = (16,) + model.resamp_with_conv = True + model.conditional = True + model.fir = True + model.fir_kernel = [1, 3, 3, 1] + model.skip_rescale = True + model.resblock_type = 'biggan' + model.progressive = 'none' + model.progressive_input = 'residual' + model.progressive_combine = 'sum' + model.attention_type = 'ddpm' + model.embedding_type = 'positional' + model.init_scale = 0. + model.fourier_scale = 16 + model.conv_size = 3 + + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py new file mode 100644 index 000000000..aed312556 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_plain_unet_debug.py @@ -0,0 +1,46 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. 
+# Modifications copyright 2024 Henry Addison +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Debug config for training a purely deterministic model. + +This is opposed to using a model ready for score-based denoising +but training it in a deterministic fashion. +""" + +from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs + +def get_config(): + config = get_default_configs() + + # training + training = config.training + training.n_epochs = 2 + training.snapshot_freq = 5 + training.eval_freq = 100 + training.log_freq = 50 + training.batch_size = 2 + + # data + data = config.data + data.dataset_name = 'debug-sample' + data.time_inputs = True + + # model + model = config.model + model.name = 'det_cunet' + + return config diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py new file mode 100644 index 000000000..2754c4a10 --- /dev/null +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/det_cunet.py @@ -0,0 +1,45 @@ +import logging +import torch.nn as nn + +from ml_downscaling_emulator.unet import unet +from . import utils + +def create_model(config, num_predictors): + if config.model.name == "u-net": + return unet.UNet(num_predictors, 1) + +from mlde_utils.training.dataset import get_variables + +###################################### +# !!!! DETERMINISTIC ONLY !!!! 
# +# This model does not use the time # +# or denoising channels at all # +###################################### + +@utils.register_model(name='det_cunet') +class DetPredNet(nn.Module): + """A purely deterministic plain U-Net with conditioning input.""" + + def __init__(self, config): + """Initialize a deterministic U-Net. + """ + if not config.deterministic: + logging.warning("Only use det_cunet for deterministic approach") + + super().__init__() + self.config = config + + cond_var_channels, output_channels = list(map(len, get_variables(config.data.dataset_name))) + if config.data.time_inputs: + cond_time_channels = 3 + else: + cond_time_channels = 0 + input_channels = cond_var_channels + cond_time_channels + config.model.loc_spec_channels + + self.unet = unet.UNet(input_channels, output_channels) + + def forward(self, x, cond, t): + """Forward of conditioning inputs through the deterministic U-Net model. + + Since not using the score-based, denoising approach, do not need to pass the time or the channels to be denoised to the model.""" + return self.unet(cond) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index ccfa986fa..8a254ad93 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -28,7 +28,7 @@ from codetiming import Timer import logging # Keep the import below for registering all model definitions -from .models import cunet, cncsnpp +from .models import det_cunet, cunet, cncsnpp from . import losses from .models.location_params import LocationParams from . 
import sampling @@ -100,15 +100,15 @@ def train(config, workdir): tb_dir = os.path.join(workdir, "tensorboard") os.makedirs(tb_dir, exist_ok=True) + run_name = os.path.basename(workdir) run_config = dict( dataset=config.data.dataset_name, input_transform_key=config.data.input_transform_key, target_transform_key=config.data.target_transform_key, architecture=config.model.name, sde=config.training.sde, - name=os.path.basename(workdir), + name=run_name, ) - run_name = os.path.basename(workdir) with track_run( EXPERIMENT_NAME, run_name, run_config, ["score_sde"], tb_dir diff --git a/tests/smoke-test-det-plain-unet b/tests/smoke-test-det-plain-unet new file mode 100755 index 000000000..336b5d961 --- /dev/null +++ b/tests/smoke-test-det-plain-unet @@ -0,0 +1,21 @@ +#! /usr/bin/env bash + +set -euo pipefail + +# config_name="ukcp_local_pr_debug" +config_name="ukcp_local_pr_pure_deterministic_debug" + +workdir="output/test/deterministic/${config_name}/test-run" +config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py" + +loc_spec_channels=0 + +rm -rf ${workdir} +WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.model.loc_spec_channels=${loc_spec_channels} + +num_samples=2 +eval_batch_size=32 +checkpoint="epoch_2" + +rm -rf "${workdir}/samples/${checkpoint}" +python bin/predict.py ${workdir} --dataset debug-sample --checkpoint ${checkpoint} --batch-size ${eval_batch_size} --num-samples ${num_samples} --ensemble-member 01 From de2d0a9f28dff37840395fc3c8c331177edcf287 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 6 Aug 2024 13:12:47 +0100 Subject: [PATCH 5/8] add a way to disable the EMA updating so can basically disable EMA with a flag. 
This is another difference between u-net trained on score_sde side deterministically and the separate deterministic training approach. In theory decay rate of 1 should allow this but it's complicated by the num_updates param too --- .../configs/deterministic/default_configs.py | 1 + .../ukcp_local_pr_12em_plain_unet.py | 1 + .../score_sde_pytorch/models/ema.py | 25 +++++++++++-------- .../score_sde_pytorch/run_lib.py | 3 ++- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py index 12db7ee2e..7c2c38ebd 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py @@ -71,6 +71,7 @@ def get_default_configs(): model.loc_spec_channels = 0 model.num_scales = 1 model.ema_rate = 0.9999 + model.ema_disabled = False model.dropout = 0.1 model.embedding_type = 'fourier' diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py index 7b141272c..50432dd1f 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py @@ -43,6 +43,7 @@ def get_config(): # model model = config.model model.name = 'det_cunet' + model.ema_disabled = True # optimizer optim = config.optim diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py index 47075fc8f..3bb26ae33 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py +++
b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py @@ -12,7 +12,7 @@ class ExponentialMovingAverage: Maintains (exponential) moving average of a set of parameters. """ - def __init__(self, parameters, decay, use_num_updates=True): + def __init__(self, parameters, decay, use_num_updates=True, disable_update=False): """ Args: parameters: Iterable of `torch.nn.Parameter`; usually the result of @@ -28,6 +28,7 @@ def __init__(self, parameters, decay, use_num_updates=True): self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] + self.disable_update = disable_update def update(self, parameters): """ @@ -40,15 +41,19 @@ def update(self, parameters): parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. """ - decay = self.decay - if self.num_updates is not None: - self.num_updates += 1 - decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) - one_minus_decay = 1.0 - decay - with torch.no_grad(): - parameters = [p for p in parameters if p.requires_grad] - for s_param, param in zip(self.shadow_params, parameters): - s_param.sub_(one_minus_decay * (s_param - param)) + if not self.disable_update: + decay = self.decay + if self.num_updates is not None: + self.num_updates += 1 + decay = min(decay, (1 + self.num_updates) / (10 + self.num_updates)) + one_minus_decay = 1.0 - decay + with torch.no_grad(): + parameters = [p for p in parameters if p.requires_grad] + for s_param, param in zip(self.shadow_params, parameters): + s_param.sub_(one_minus_decay * (s_param - param)) + else: + # if disabled then just maintain a copy of the parameters + self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] def copy_to(self, parameters): """ diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index 8a254ad93..f391fab95 100644 --- 
a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -124,7 +124,8 @@ def train(config, workdir): location_params = LocationParams(config.model.loc_spec_channels, config.data.image_size) location_params = location_params.to(config.device) location_params = torch.nn.DataParallel(location_params) - ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate) + ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate, disable_update=config.model.ema_disabled) + optimizer = losses.get_optimizer(config, itertools.chain(score_model.parameters(), location_params.parameters())) state = dict(optimizer=optimizer, model=score_model, location_params=location_params, ema=ema, step=0, epoch=0) From 30733b113258fed3aa18bb559103afc52960d0df Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Fri, 9 Aug 2024 10:45:11 +0100 Subject: [PATCH 6/8] use decay/ema_rate to effectively disable EMA a rate of 1 means no EMA this is backwards compatible unlike adding a new config attribute --- .../configs/deterministic/default_configs.py | 1 - .../deterministic/ukcp_local_pr_12em_plain_unet.py | 2 +- .../ukcp_local_pr_12em_tuned_plain_unet.py | 1 - .../score_sde_pytorch/models/ema.py | 12 ++++++------ .../score_sde_pytorch/run_lib.py | 2 +- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py index 7c2c38ebd..12db7ee2e 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/default_configs.py @@ -71,7 +71,6 @@ def get_default_configs(): model.loc_spec_channels = 0 model.num_scales 
= 1 model.ema_rate = 0.9999 - model.ema_disabled = False model.dropout = 0.1 model.embedding_type = 'fourier' diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py index 50432dd1f..fc5ce7c6d 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_plain_unet.py @@ -43,7 +43,7 @@ def get_config(): # model model = config.model model.name = 'det_cunet' - model.ema_disabled = True + model.ema_rate = 1 # basically disables EMA # optimizer optim = config.optim diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py index d032d3175..2cde45932 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/ukcp_local_pr_12em_tuned_plain_unet.py @@ -43,7 +43,6 @@ def get_config(): # model model = config.model model.name = 'det_cunet' - model.ema_disabled = False # optimizer optim = config.optim diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py index 3bb26ae33..dd57fa3c1 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/models/ema.py @@ -12,7 +12,7 @@ class ExponentialMovingAverage: Maintains (exponential) moving average of a set of parameters. 
""" - def __init__(self, parameters, decay, use_num_updates=True, disable_update=False): + def __init__(self, parameters, decay, use_num_updates=True): """ Args: parameters: Iterable of `torch.nn.Parameter`; usually the result of @@ -28,7 +28,6 @@ def __init__(self, parameters, decay, use_num_updates=True, disable_update=False self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] self.collected_params = [] - self.disable_update = disable_update def update(self, parameters): """ @@ -41,7 +40,10 @@ def update(self, parameters): parameters: Iterable of `torch.nn.Parameter`; usually the same set of parameters used to initialize this object. """ - if not self.disable_update: + if self.decay == 1: + # complete decay so just maintain a copy of the parameters + self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] + else: decay = self.decay if self.num_updates is not None: self.num_updates += 1 @@ -51,9 +53,7 @@ def update(self, parameters): parameters = [p for p in parameters if p.requires_grad] for s_param, param in zip(self.shadow_params, parameters): s_param.sub_(one_minus_decay * (s_param - param)) - else: - # if disabled then just maintain a copy of the parameters - self.shadow_params = [p.clone().detach() for p in parameters if p.requires_grad] + def copy_to(self, parameters): """ diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index f391fab95..e185a675d 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -124,7 +124,7 @@ def train(config, workdir): location_params = LocationParams(config.model.loc_spec_channels, config.data.image_size) location_params = location_params.to(config.device) location_params = torch.nn.DataParallel(location_params) - ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), 
decay=config.model.ema_rate, disable_update=config.model.ema_disabled) + ema = ExponentialMovingAverage(itertools.chain(score_model.parameters(), location_params.parameters()), decay=config.model.ema_rate) optimizer = losses.get_optimizer(config, itertools.chain(score_model.parameters(), location_params.parameters())) state = dict(optimizer=optimizer, model=score_model, location_params=location_params, ema=ema, step=0, epoch=0) From 869127ccc8f02cdf83166097d2188c2535509b6b Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Mon, 5 Aug 2024 09:52:40 +0100 Subject: [PATCH 7/8] add helper scripts for queuing model jobs on jasmin --- bin/jasmin/lotus-wrapper | 19 +++++++++++++++++++ bin/jasmin/queue-mlde | 30 ++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) create mode 100755 bin/jasmin/lotus-wrapper create mode 100755 bin/jasmin/queue-mlde diff --git a/bin/jasmin/lotus-wrapper b/bin/jasmin/lotus-wrapper new file mode 100755 index 000000000..fe98b706c --- /dev/null +++ b/bin/jasmin/lotus-wrapper @@ -0,0 +1,19 @@ +#!/bin/bash +# Wrapper script around commands for interacting with a model to queue on LOTUS on JASMIN + +module load gcc + +source ~/.bashrc +mamba activate mv-mlde + +set -euo pipefail + +cd /home/users/vf20964/code/mlde + +export DERIVED_DATA=/gws/nopw/j04/bris_climdyn/henrya/bp-backups/ +export KK_SLACK_WH_URL=https://hooks.slack.com +export WANDB_EXPERIMENT_NAME="ml-downscaling-emulator" + +nvidia-smi + +$@ diff --git a/bin/jasmin/queue-mlde b/bin/jasmin/queue-mlde new file mode 100755 index 000000000..8f25cfc7b --- /dev/null +++ b/bin/jasmin/queue-mlde @@ -0,0 +1,30 @@ +#!/bin/bash +# Script for queueing a model job on LOTUS on JASMIN via lotus-wrapper script + +set -euo pipefail + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + +smem=128G +stime=1-00:00:00 + +while getopts ":m:t:" opt; do + case ${opt} in + m) + smem=${OPTARG} + ;; + t) + stime=${OPTARG} + ;; + \? 
) + # echo "Invalid option: -${OPTARG}" 1>&2 + ;; + : ) + echo "Invalid option: $OPTARG requires an argument" 1>&2 + exit 1 + ;; + esac +done +shift "$((OPTIND -1))" + +sbatch --parsable --gres=gpu:1 --partition=orchid --account=orchid --time=${stime} --mem=${smem} -- ${SCRIPT_DIR}/lotus-wrapper $@ From bb37ccd87aa3950fb7260a4d5eb89a40839a1da1 Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Fri, 9 Aug 2024 12:38:52 +0100 Subject: [PATCH 8/8] allow for missing deterministic key on config for backwards compatibility --- bin/predict.py | 3 ++- src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py | 7 ++++--- src/ml_downscaling_emulator/score_sde_pytorch/sampling.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/bin/predict.py b/bin/predict.py index 18b1b4828..169e07bb6 100644 --- a/bin/predict.py +++ b/bin/predict.py @@ -98,7 +98,8 @@ def _init_state(config): def load_model(config, ckpt_filename): - if config.deterministic: + deterministic = "deterministic" in config and config.deterministic + if deterministic: sde = None sampling_eps = 0 else: diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py index e185a675d..14c398cba 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/run_lib.py @@ -140,7 +140,8 @@ def train(config, workdir): initial_epoch = int(state['epoch'])+1 # start from the epoch after the one currently reached # Setup SDEs - if config.deterministic: + deterministic = "deterministic" in config and config.deterministic + if deterministic: sde = None else: if config.training.sde.lower() == 'vpsde': @@ -163,11 +164,11 @@ def train(config, workdir): train_step_fn = losses.get_step_fn(sde, train=True, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, - deterministic=config.deterministic,) + deterministic=deterministic,) 
eval_step_fn = losses.get_step_fn(sde, train=False, optimize_fn=optimize_fn, reduce_mean=reduce_mean, continuous=continuous, likelihood_weighting=likelihood_weighting, - deterministic=config.deterministic,) + deterministic=deterministic,) num_train_epochs = config.training.n_epochs diff --git a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py index 8ee38c919..556f0328f 100644 --- a/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py +++ b/src/ml_downscaling_emulator/score_sde_pytorch/sampling.py @@ -95,7 +95,7 @@ def get_sampling_fn(config, sde, shape, eps): trailing dimensions matching `shape`. """ - if config.deterministic: + if "deterministic" in config and config.deterministic: sampling_fn = get_deterministic_sampler(shape, device=config.device) else: sampler_name = config.sampling.method