A machine learning emulator of a CPM based on a diffusion model.
This is the code for the paper Addison et al. (2024) "Machine learning emulation of precipitation from km-scale regional climate simulations using a diffusion model".
The diffusion model implementation is forked from the PyTorch implementation for the paper "Score-Based Generative Modeling through Stochastic Differential Equations" by Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole.
- Clone the repo and cd into it.
- Create the conda environment:
  ```sh
  conda env create -f environment.lock.yml
  ```
  (or add the dependencies to your own environment: `conda env update -f environment.txt`)
- Activate the conda environment (if not already done).
- Install ml_downscaling_emulator locally:
  ```sh
  pip install -e .
  ```
- [Optional] Install the U-Net code (only necessary if you wish to use the deterministic comparison models):
  ```sh
  git clone --depth 1 https://github.com/henryaddison/Pytorch-UNet.git src/ml_downscaling_emulator/unet
  ```
- Configure application behaviour with environment variables. See `.env.example` for variables that can be set.
Any datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` is the name of the dataset used to train the model.

Datasets for use with the emulator can be created using [mlde-data](https://github.com/henryaddison/mlde-data). That repo contains further information about dataset specification. The datasets used in the paper can be found on Zenodo.

NB the interface commonly takes just the name of a dataset, which is expected to be found at `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/` (where `DERIVED_DATA` is a configurable environment variable).
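As a minimal sketch of this location convention (the helper function here is hypothetical, not part of the package):

```python
import os
from pathlib import Path

# Hypothetical helper illustrating the dataset location convention:
# datasets live under ${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/
def dataset_dir(dataset_name: str) -> Path:
    derived_data = Path(os.environ["DERIVED_DATA"])
    return derived_data / "moose" / "nc-datasets" / dataset_name

os.environ["DERIVED_DATA"] = "/tmp/derived"  # example value for illustration
print(dataset_dir("my-dataset"))  # /tmp/derived/moose/nc-datasets/my-dataset
```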
`tests/smoke-test` uses a simpler network to test the full training and sampling regime. It is recommended to run it with a sample of the dataset.
Train models through `bin/main.py`, e.g. to train the model used in the paper:

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```
```
main.py:
  --mode: <train>: Running mode: train
  --workdir: Working directory for storing data related to the model, such as model snapshots, transforms or samples
  --config: Training configuration.
    (default: 'None')
```
- `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.
- `workdir` is the path that stores all artifacts of one experiment, such as checkpoints, transforms and samples. It is recommended to be a subdirectory of `${DERIVED_DATA}`.
- `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to `ml_collections` and are heavily based on ncsnpp config files.

Naming conventions of config files: the path of a config file is a combination of the following dimensions:
- SDE: `subvpsde`
- data source: `ukcp_local`
- variable: `pr`
- ensemble members: `12em` (all 12) or `1em` (single)
- model: `cncsnpp`
- continuous: train the model with continuously sampled time steps.
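As an illustrative sketch of this naming convention (this parser is not part of the repo, and the field labels are my own):

```python
# Decompose a config path into the naming-convention dimensions described
# above. Assumes the last four underscore-separated fields are always
# variable, ensemble members, model and time-step schedule.
def parse_config_name(path: str) -> dict:
    sde, name = path.removesuffix(".py").split("/")[-2:]
    *source_parts, variable, members, model, schedule = name.split("_")
    return {
        "sde": sde,
        "data_source": "_".join(source_parts),
        "variable": variable,
        "ensemble_members": members,
        "model": model,
        "time_steps": schedule,
    }

print(parse_config_name("subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py"))
```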
Functionalities can be configured through config files or, more conveniently, through the command-line support of the `ml_collections` package.
Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```sh
python bin/predict.py --checkpoint epoch_20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERIVED_DATA}/path/to/models/paper-12em
```
This example command will:

- use the checkpoint of the model in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config from training, `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`.
- store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_dataset}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
- generate samples conditioned on examples from ensemble member `01` in the `test` subset of the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset.
- transform the inputs based on the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset using the `pixelmmsstan` approach.
- generate 1 set of samples.