
Commit

a few updates to the README instructions
henryaddison committed Jun 12, 2024
1 parent b6a6e95 commit a731e9b
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions README.md
@@ -2,6 +2,8 @@

A machine learning emulator of a CPM based on a diffusion model.

This is the code for the paper Addison et al. (2024) "Machine learning emulation of precipitation from km-scale regional climate simulations using a diffusion model".

Diffusion model implementation forked from PyTorch implementation for the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.sohldickstein.com/), [Diederik P. Kingma](http://dpkingma.com/), [Abhishek Kumar](http://users.umiacs.umd.edu/~abhishek/), [Stefano Ermon](https://cs.stanford.edu/~ermon/), and [Ben Poole](https://cs.stanford.edu/~poole/).

## Dependencies
@@ -10,12 +12,20 @@ Diffusion model implementation forked from PyTorch implementation for the paper
2. Create the conda environment: `conda env create -f environment.lock.yml` (or add the dependencies to your own environment: `conda env update -f environment.txt`)
3. Activate the conda environment (if you have not already done so)
4. Install ml_downscaling_emulator locally: `pip install -e .`
5. Install U-Net code: `git clone --depth 1 https://github.com/henryaddison/Pytorch-UNet.git src/ml_downscaling_emulator/unet`
5. \[Optional\] Install U-Net code: `git clone --depth 1 https://github.com/henryaddison/Pytorch-UNet.git src/ml_downscaling_emulator/unet` - this is only necessary if you wish to use the deterministic comparison models.
6. Configure application behaviour with environment variables. See `.env.example` for variables that can be set.
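For step 6, a minimal `.env` might set only the data root used elsewhere in this README (the value is a placeholder; see `.env.example` for the full list of variables):

```sh
# Root directory under which datasets and model workdirs are stored
DERIVED_DATA=/path/to/derived/data
```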

Any datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key config.data.dataset_name is the name of the dataset to use to train the model.

## Usage
## Diffusion Model Usage

### Data

Datasets for use with the emulator can be created using [mlde-data](https://github.com/henryaddison/mlde-data).
This repo contains further information about dataset specification.
The datasets used in the paper can be found on [Zenodo](https://doi.org/10.5281/zenodo.11504859).

**NB** the interface commonly takes just the name of a dataset. It is expected to be found at `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/` (where `DERIVED_DATA` is a configurable environment variable).
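As an illustration of this convention, the expected on-disk location of a named dataset can be computed as below (the helper function is ours for illustration, not part of the package):

```python
import os
from pathlib import Path


def dataset_dir(dataset_name: str) -> Path:
    """Expected location of a named dataset under ${DERIVED_DATA}."""
    # Layout convention: ${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/
    return Path(os.environ["DERIVED_DATA"]) / "moose" / "nc-datasets" / dataset_name
```

For example, with `DERIVED_DATA=/data`, a dataset named `my-dataset` is expected at `/data/moose/nc-datasets/my-dataset/`.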

### Smoke test

@@ -28,7 +38,7 @@ Recommended to run with a sample of the dataset.

### Training

Train models through `bin/main.py`, e.g.
Train models through `bin/main.py`, e.g. to train the model used in the paper:

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```
@@ -64,12 +74,12 @@ Functionalities can be configured through config files, or more conveniently, th
Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```sh
python bin/predict.py --checkpoint epoch-20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERVIED_DATA}/path/to/models/paper-12em
python bin/predict.py --checkpoint epoch_20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERIVED_DATA}/path/to/models/paper-12em
```

This example command will:
* use the checkpoint of the model in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config from training in `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`.
* store samples generated in `${DERVIED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files and named like `predictions-{uuid}.nc`.
* store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
* generate samples conditioned on examples from ensemble member `01` in the `test` subset of the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset.
* transform the inputs based on the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset using the `pixelmmsstan` approach.
* generate 1 set of samples.
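The sample-file layout described above can be mirrored to locate generated files programmatically. A hypothetical helper (the function name and argument names are ours, not part of the package):

```python
from pathlib import Path


def find_sample_files(workdir, dataset, transform_dataset, transform_key,
                      split, ensemble_member):
    """Glob predictions-{uuid}.nc files under the samples directory layout:
    {workdir}/samples/{dataset}/{transform_dataset}-{transform_key}/{split}/{ensemble_member}/
    """
    samples_dir = (Path(workdir) / "samples" / dataset
                   / f"{transform_dataset}-{transform_key}"
                   / split / ensemble_member)
    return sorted(samples_dir.glob("predictions-*.nc"))
```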
