Skip to content

Commit

Permalink
Merge pull request #36 from henryaddison/deterministic-ncsnpp
Browse files Browse the repository at this point in the history
Deterministic training in score_sde_pytorch package
  • Loading branch information
henryaddison authored Aug 13, 2024
2 parents eaff5fa + bb37ccd commit 6fac894
Show file tree
Hide file tree
Showing 20 changed files with 706 additions and 80 deletions.
19 changes: 19 additions & 0 deletions bin/jasmin/lotus-wrapper
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#!/bin/bash
# Wrapper script around commands for interacting with a model to queue on LOTUS on JASMIN
#
# Usage: lotus-wrapper <command> [args...]
# Loads the runtime environment on a LOTUS compute node, then runs the
# given command (with GPU diagnostics logged first).

module load gcc

source ~/.bashrc
mamba activate mv-mlde

# Enable strict mode only after sourcing ~/.bashrc, which may reference
# unset variables and would otherwise abort under `-u`.
set -euo pipefail

cd /home/users/vf20964/code/mlde

export DERIVED_DATA=/gws/nopw/j04/bris_climdyn/henrya/bp-backups/
export KK_SLACK_WH_URL=https://hooks.slack.com
export WANDB_EXPERIMENT_NAME="ml-downscaling-emulator"

# Log GPU state for debugging before the job proper starts.
nvidia-smi

# "$@" (quoted) preserves arguments containing spaces; exec replaces this
# shell so the wrapped command receives SLURM signals directly.
exec "$@"
30 changes: 30 additions & 0 deletions bin/jasmin/queue-mlde
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#!/bin/bash
# Script for queueing a model job on LOTUS on JASMIN via lotus-wrapper script
#
# Usage: queue-mlde [-m MEM] [-t TIME] <command> [args...]
#   -m  memory request for the SLURM job (default: 128G)
#   -t  wall-clock time limit (default: 1-00:00:00)

set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

# Defaults for the SLURM resource request, overridable via -m / -t.
smem=128G
stime=1-00:00:00

while getopts ":m:t:" opt; do
  case ${opt} in
    m)
      smem=${OPTARG}
      ;;
    t)
      stime=${OPTARG}
      ;;
    \? )
      # Fail fast on unknown options instead of silently ignoring them.
      echo "Invalid option: -${OPTARG}" 1>&2
      exit 1
      ;;
    : )
      echo "Invalid option: $OPTARG requires an argument" 1>&2
      exit 1
      ;;
  esac
done
shift "$((OPTIND -1))"

# Quote "$@" and variable expansions so arguments containing spaces
# survive the hand-off to sbatch and the wrapper script.
sbatch --parsable --gres=gpu:1 --partition=orchid --account=orchid --time="${stime}" --mem="${smem}" -- "${SCRIPT_DIR}/lotus-wrapper" "$@"
50 changes: 28 additions & 22 deletions bin/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import cunet # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import det_cunet # noqa: F401

from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
layerspp, # noqa: F401
Expand Down Expand Up @@ -97,29 +98,34 @@ def _init_state(config):


def load_model(config, ckpt_filename):
if config.training.sde == "vesde":
sde = VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales,
)
sampling_eps = 1e-5
elif config.training.sde == "vpsde":
sde = VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales,
)
sampling_eps = 1e-3
elif config.training.sde == "subvpsde":
sde = subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales,
)
sampling_eps = 1e-3
deterministic = "deterministic" in config and config.deterministic
if deterministic:
sde = None
sampling_eps = 0
else:
raise RuntimeError(f"Unknown SDE {config.training.sde}")
if config.training.sde == "vesde":
sde = VESDE(
sigma_min=config.model.sigma_min,
sigma_max=config.model.sigma_max,
N=config.model.num_scales,
)
sampling_eps = 1e-5
elif config.training.sde == "vpsde":
sde = VPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales,
)
sampling_eps = 1e-3
elif config.training.sde == "subvpsde":
sde = subVPSDE(
beta_min=config.model.beta_min,
beta_max=config.model.beta_max,
N=config.model.num_scales,
)
sampling_eps = 1e-3
else:
raise RuntimeError(f"Unknown SDE {config.training.sde}")

# sigmas = mutils.get_sigmas(config) # noqa: F841
state = _init_state(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def get_config():
evaluate.batch_size = 64

config.data = data = ml_collections.ConfigDict()
data.dataset_name = (
"bham_gcmx-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
)
data.dataset_name = "bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr"
data.input_transform_key = "stan"
data.target_transform_key = "sqrturrecen"
data.input_transform_dataset = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,5 +74,6 @@ def get_default_configs():

config.seed = 42
config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
config.deterministic = False

return config
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Modifications copyright 2024 Henry Addison
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Defaults for training in a deterministic fashion."""

import ml_collections
import torch

def get_default_configs():
    """Return the shared default ConfigDict for deterministic training runs.

    Deterministic models skip the SDE machinery entirely, so
    ``training.sde`` is left empty and ``config.deterministic`` is True.
    """
    config = ml_collections.ConfigDict()

    # Flag consumed downstream (e.g. model loading) to skip SDE construction.
    config.deterministic = True

    # training
    config.training = training = ml_collections.ConfigDict()
    training.n_epochs = 20
    training.batch_size = 16
    training.snapshot_freq = 25
    training.log_freq = 500
    training.eval_freq = 5000
    ## store additional checkpoints for preemption in cloud computing environments
    training.snapshot_freq_for_preemption = 1000
    ## produce samples at each snapshot.
    training.snapshot_sampling = False
    training.likelihood_weighting = False
    # NOTE: the original set `continuous` twice and `reduce_mean` to both
    # False and True; the duplicates are removed, keeping the values that
    # were last assigned and therefore effective (True for both).
    training.continuous = True
    training.reduce_mean = True
    training.random_crop_size = 0
    # Deterministic training has no SDE; leave the identifier empty.
    training.sde = ""

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()

    # evaluation
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.batch_size = 128

    # data
    config.data = data = ml_collections.ConfigDict()
    data.dataset = 'UKCP_Local'
    data.image_size = 64
    data.random_flip = False
    data.uniform_dequantization = False
    data.time_inputs = False
    data.centered = True
    data.input_transform_key = "stan"
    data.target_transform_key = "sqrturrecen"

    # model
    config.model = model = ml_collections.ConfigDict()
    model.sigma_min = 0.01
    model.sigma_max = 50
    model.beta_min = 0.1
    model.beta_max = 20.
    model.loc_spec_channels = 0
    model.num_scales = 1
    model.ema_rate = 0.9999
    model.dropout = 0.1
    model.embedding_type = 'fourier'

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.optimizer = 'Adam'
    optim.lr = 2e-4
    optim.beta1 = 0.9
    optim.eps = 1e-8
    optim.weight_decay = 0
    optim.warmup = 5000
    optim.grad_clip = 1.

    config.seed = 42
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    return config
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Modifications copyright 2024 Henry Addison
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training NCSN++ on precip data in a deterministic fashion."""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs

def get_config():
    """Return the config for deterministic training of NCSN++ on precip data.

    Builds on the deterministic defaults and selects the conditional
    NCSN++ architecture (`cncsnpp`).
    """
    config = get_default_configs()

    # training
    training = config.training
    training.n_epochs = 100

    # data
    data = config.data
    data.dataset_name = 'bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr'

    # model
    model = config.model
    model.name = 'cncsnpp'
    model.loc_spec_channels = 0
    model.dropout = 0.1
    # NOTE: the original assigned embedding_type twice ('fourier' then
    # 'positional'); only the final, effective value is kept.
    model.embedding_type = 'positional'
    model.scale_by_sigma = False
    model.ema_rate = 0.9999
    model.normalization = 'GroupNorm'
    model.nonlinearity = 'swish'
    model.nf = 128
    model.ch_mult = (1, 2, 2, 2)
    model.num_res_blocks = 4
    model.attn_resolutions = (16,)
    model.resamp_with_conv = True
    model.conditional = True
    model.fir = True
    model.fir_kernel = [1, 3, 3, 1]
    model.skip_rescale = True
    model.resblock_type = 'biggan'
    model.progressive = 'none'
    model.progressive_input = 'residual'
    model.progressive_combine = 'sum'
    model.attention_type = 'ddpm'
    model.init_scale = 0.
    model.fourier_scale = 16
    model.conv_size = 3

    return config
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Modifications copyright 2024 Henry Addison
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Debug config for training a purely deterministic model.
This is opposed to using a model ready for score-based denoising
but training it in a deterministic fashion.
"""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs

def get_config():
    """Return the debug config for a purely deterministic U-Net model.

    Unlike the score-based variants, this selects the `det_cunet`
    architecture and effectively disables EMA, warmup and grad clipping.
    """
    config = get_default_configs()

    # training
    config.training.n_epochs = 100
    config.training.snapshot_freq = 20
    config.training.batch_size = 64

    # data
    config.data.dataset_name = 'bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr'
    config.data.input_transform_key = "stan"
    config.data.target_transform_key = "sqrturrecen"
    config.data.input_transform_dataset = None
    config.data.time_inputs = False

    # model
    config.model.name = 'det_cunet'
    config.model.ema_rate = 1  # basically disables EMA

    # optimizer
    config.optim.optimizer = "Adam"
    config.optim.lr = 2e-4
    config.optim.beta1 = 0.9
    config.optim.eps = 1e-8
    config.optim.weight_decay = 0
    config.optim.warmup = -1  # 5000
    config.optim.grad_clip = -1.  # 1.

    return config
Loading

0 comments on commit 6fac894

Please sign in to comment.