Merge pull request #33 from henryaddison/paper
Prep code for paper publishing
henryaddison authored May 9, 2024
2 parents 5ea8e6a + 3783052 commit 20a3a46
Showing 133 changed files with 380 additions and 4,022 deletions.
6 changes: 6 additions & 0 deletions .env.example
@@ -0,0 +1,6 @@
# [required] root path for where to find datasets and store models
DERIVED_DATA=/path/to/derived_data
# [optional] log level
LOG_LEVEL=INFO
# [optional] slack webhook url for training and samples notifications
KK_SLACK_WH_URL=https://hooks.slack.com
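For illustration, a minimal sketch (an assumption, not the package's actual startup code) of how these variables might be read, using `python-dotenv`:

```python
# Illustrative only: reading the variables documented in .env.example.
import logging
import os

from dotenv import load_dotenv

load_dotenv()  # pick up variables from a local .env file

derived_data = os.environ["DERIVED_DATA"]    # required: datasets and models live under here
log_level = os.getenv("LOG_LEVEL", "INFO")   # optional: defaults to INFO
slack_wh_url = os.getenv("KK_SLACK_WH_URL")  # optional: None disables notifications

logging.basicConfig(level=log_level)
```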
33 changes: 33 additions & 0 deletions .github/workflows/ci.yml
@@ -0,0 +1,33 @@
name: CI

on: [push]

jobs:
  ci-checks:
    runs-on: ubuntu-latest
    strategy:
      max-parallel: 5

    steps:
      - name: Clone repo
        uses: actions/checkout@v4
      - name: setup-micromamba
        uses: mamba-org/setup-micromamba@v1
        with:
          environment-file: environment.lock.yml
          init-shell: bash
          cache-environment: true
          post-cleanup: 'all'
      - name: Install package
        run: |
          pip install -e .
        shell: micromamba-shell {0}
      - name: Install unet
        uses: actions/checkout@v4
        with:
          repository: henryaddison/Pytorch-UNet
          path: src/ml_downscaling_emulator/unet
      - name: Test with pytest
        run: |
          pytest
        shell: micromamba-shell {0}
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -13,9 +13,9 @@ repos:
    hooks:
      - id: black
        language_version: python3.9
        exclude: ^src/ml_downscaling_emulator/score_sde_pytorch_hja22/
        exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
  - repo: https://github.com/pycqa/flake8
    rev: '6.0.0' # pick a git hash / tag to point to
    hooks:
      - id: flake8
        exclude: ^src/ml_downscaling_emulator/score_sde_pytorch_hja22/
        exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
104 changes: 50 additions & 54 deletions README.md
@@ -1,79 +1,75 @@
# ML Downscaling Emulator

Forked from PyTorch implementation for the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS)
A machine learning emulator of a CPM based on a diffusion model.

by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.sohldickstein.com/), [Diederik P. Kingma](http://dpkingma.com/), [Abhishek Kumar](http://users.umiacs.umd.edu/~abhishek/), [Stefano Ermon](https://cs.stanford.edu/~ermon/), and [Ben Poole](https://cs.stanford.edu/~poole/)
Diffusion model implementation forked from PyTorch implementation for the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.sohldickstein.com/), [Diederik P. Kingma](http://dpkingma.com/), [Abhishek Kumar](http://users.umiacs.umd.edu/~abhishek/), [Stefano Ermon](https://cs.stanford.edu/~ermon/), and [Ben Poole](https://cs.stanford.edu/~poole/).

## Dependencies

1. Create conda environment: `conda env create -f environment.lock.yml`
2. Clone and install https://github.com/henryaddison/mlde_utils into the environment: e.g. `pip install -e ../mlde_utils`
3. Install ml_downscaling_emulator locally: `pip install -e .`
4. Install unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
5. Configure necessary environment variables: `DERIVED_DATA` and `KK_SLACK_WH_URL`
1. Clone repo and cd into it
2. Create conda environment: `conda env create -f environment.lock.yml` (or add the dependencies to your own environment: `conda env update -f environment.txt`)
3. Activate the conda environment (if not already done so)
4. Install ml_downscaling_emulator locally: `pip install -e .`
5. Install unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
6. Configure application behaviour with environment variables. See `.env.example` for variables that can be set.

### Usage
Datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` is the name of the dataset used to train the model.
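For illustration, a minimal sketch (not part of the package API) of where a named dataset would be looked up:

```python
# Illustrative only: locating a dataset under DERIVED_DATA by name.
import os
from pathlib import Path

import xarray as xr

dataset_name = "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
dataset_dir = Path(os.environ["DERIVED_DATA"]) / "moose" / "nc-datasets" / dataset_name

# The file names inside the dataset directory are an assumption; adjust to your data.
ds = xr.open_dataset(dataset_dir / "train.nc")
print(ds.data_vars)
```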

## Usage

#### Smoke test
### Smoke test

`bin/local-test-train`
```sh
tests/smoke-test
```

Uses a simpler network to test the full training and sampling regime.
It is recommended to run it with a sample of the dataset.

#### Training
### Training

Train models through `main.py`.
Train models through `bin/main.py`, e.g.

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```

```sh
main.py:
  --config: Training configuration.
    (default: 'None')
  --mode: <train>: Running mode: train
  --workdir: Working directory for storing data related to the model such as model snapshots, transforms or samples
```

* `config` is the path to the config file. Our prescribed config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and should be quite self-explanatory.
* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.

* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, transforms and samples. Recommended to be a subdirectory of ${DERIVED_DATA}.

* `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and heavily based on ncsnpp config files.

**Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
* dataset: One of `cifar10`, `celeba`, `celebahq`, `celebahq_256`, `ffhq_256`, `celebahq`, `ffhq`.
* model: One of `ncsn`, `ncsnv2`, `ncsnpp`, `ddpm`, `ddpmpp`.
* SDE: `subvpsde`
* data source: `ukcp_local`
* variable: `pr`
* ensemble members: `12em` (all 12) or `1em` (single)
* model: `cncsnpp`
* continuous: train the model with continuously sampled time steps.

* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.


These functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package. For example, to generate samples and evaluate sample quality, supply the `--config.eval.enable_sampling` flag; to compute log-likelihoods, supply the `--config.eval.enable_bpd` flag, and specify `--config.eval.dataset=train/test` to indicate whether to compute the likelihoods on the training or test dataset.

#### Sampling
TODO

## How to extend the code
* **New SDEs**: inherit from the `sde_lib.SDE` abstract class and implement all abstract methods. The `discretize()` method is optional and the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for this new SDE.
* **New predictors**: inherit from the `sampling.Predictor` abstract class, implement the `update_fn` abstract method, and register its name with `@register_predictor`, as sketched below. The new predictor can be directly used in `sampling.get_pc_sampler` for Predictor-Corrector sampling, and all other controllable generation methods in `controllable_generation.py`.
* **New correctors**: inherit from the `sampling.Corrector` abstract class, implement the `update_fn` abstract method, and register its name with `@register_corrector`. The new corrector can be directly used in `sampling.get_pc_sampler`, and all other controllable generation methods in `controllable_generation.py`.
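A rough sketch of the predictor pattern (module path and constructor signature are assumptions based on the upstream score_sde_pytorch code and may differ in this fork):

```python
# Sketch only: a new predictor following the upstream score_sde_pytorch pattern.
import torch

from ml_downscaling_emulator.score_sde_pytorch import sampling


@sampling.register_predictor(name="my_euler_predictor")
class MyEulerPredictor(sampling.Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        # One Euler-Maruyama step of the reverse SDE.
        dt = -1.0 / self.rsde.N
        z = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * dt
        x = x_mean + diffusion[:, None, None, None] * ((-dt) ** 0.5) * z
        return x, x_mean
```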

## Tips
* When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via `config.training.n_jitted_steps`. For CIFAR-10, we recommend using `config.training.n_jitted_steps=5` when your GPU/TPU has sufficient memory; otherwise we recommend using `config.training.n_jitted_steps=1`. Our current implementation requires `config.training.log_freq` to be divisible by `n_jitted_steps` for logging and checkpointing to work normally.
* The `snr` (signal-to-noise ratio) parameter of `LangevinCorrector` somewhat behaves like a temperature parameter. Larger `snr` typically results in smoother samples, while smaller `snr` gives more diverse but lower quality samples. Typical values of `snr` are `0.05 - 0.2`, and it requires tuning to strike the sweet spot.
* For VE SDEs, we recommend choosing `config.model.sigma_max` to be the maximum pairwise distance between data samples in the training dataset (a rough sketch of this estimate follows below).
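A rough sketch of that estimate (an illustration, not code from this repository):

```python
# Sketch: estimate sigma_max as the largest pairwise Euclidean distance between
# (flattened) training examples, per the VE SDE tip above. Use a subsample if the
# full training set is too large for an N x N distance matrix.
import torch


def estimate_sigma_max(train_examples: torch.Tensor) -> float:
    flat = train_examples.flatten(start_dim=1)  # (N, features)
    distances = torch.cdist(flat, flat, p=2)    # (N, N) pairwise distances
    return distances.max().item()
```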

## References

This code is based on the following work:
```bib
@inproceedings{
song2021scorebased,
title={Score-Based Generative Modeling through Stochastic Differential Equations},
author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=PxTIG12RRHS}
}
```
Functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package.
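For illustration, a minimal sketch of a config module in the `ml_collections` style (only `data.dataset_name` is taken from this README; the other fields are hypothetical):

```python
# Hypothetical, minimal config module; real configs live under
# src/ml_downscaling_emulator/score_sde_pytorch/configs/ and have many more fields.
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.data = data = ml_collections.ConfigDict()
    data.dataset_name = "my-dataset"  # overridable at the CLI

    config.training = training = ml_collections.ConfigDict()
    training.snapshot_freq = 1000  # hypothetical field

    return config
```

Values in such a config can then be overridden on the command line, e.g. `--config.data.dataset_name=<dataset>`.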


### Sampling

This work is built upon some previous papers which might also interest you:
Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```sh
python bin/predict.py --checkpoint epoch-20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERIVED_DATA}/path/to/models/paper-12em
```

* Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." *Proceedings of the 33rd Annual Conference on Neural Information Processing Systems*. 2019.
* Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
* Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
This example command will:
* use the checkpoint of the model in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config from training in `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`.
* store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
* generate samples conditioned on examples from ensemble member `01` in the `test` subset of the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset.
* transform the inputs based on the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset using the `pixelmmsstan` approach.
* generate 1 set of samples.
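A minimal sketch (not a provided utility, and the exact layout may differ) for loading the generated samples back in with xarray, following the directory layout above:

```python
# Illustrative only: read back predictions-{uuid}.nc files written by bin/predict.py.
from pathlib import Path

import xarray as xr

workdir = Path("/path/to/derived_data/path/to/models/paper-12em")
dataset = "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
sample_dir = workdir / "samples" / dataset / f"{dataset}-pixelmmsstan" / "test" / "01"

for nc_file in sorted(sample_dir.glob("predictions-*.nc")):
    ds = xr.open_dataset(nc_file)
    print(nc_file.name, list(ds.data_vars))  # variable names are dataset-specific
```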
Binary file removed assets/bedroom.jpeg
Binary file not shown.
Binary file removed assets/celebahq_256.jpg
Binary file not shown.
Binary file removed assets/church.jpeg
Binary file not shown.
Binary file removed assets/ffhq_1024.jpeg
Binary file not shown.
Binary file removed assets/ffhq_256.jpg
Binary file not shown.
Binary file removed assets/ffhq_samples.jpg
Binary file not shown.
Binary file removed assets/schematic.jpg
Binary file not shown.
119 changes: 0 additions & 119 deletions bin/add-ensemble-member-dim-to-predictions

This file was deleted.

6 changes: 0 additions & 6 deletions bin/bp-jup

This file was deleted.

File renamed without changes.
File renamed without changes.
10 changes: 4 additions & 6 deletions bin/queue-training → bin/bp/queue-training
@@ -10,19 +10,18 @@ import typer
app = typer.Typer()


def train_cmd(sde, dataset, workdir, config, config_overrides=list):
def train_cmd(sde, workdir, config, config_overrides=list):
    train_basecmd = ["python", f"bin/main.py"]

    train_opts = {
        "--config": f"src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/{sde}/{config}.py",
        "--config": f"src/ml_downscaling_emulator/score_sde_pytorch/configs/{sde}/{config}.py",
        "--workdir": workdir,
        "--mode": "train",
    }

    return (
        train_basecmd
        + [arg for item in train_opts.items() for arg in item]
        + [f"--config.data.dataset_name={dataset}"]
        + config_overrides
    )

@@ -51,9 +50,8 @@ def queue_cmd(duration, memory):
def main(
    ctx: typer.Context,
    model_run_id: str,
    cpm_dataset: str,
    sde: str,
    config: str = "xarray_12em_cncsnpp_continuous",
    config: str = "ukcp_local_pr_12em_cncsnpp_continuous",
    memory: int = 64,
    duration: int = 72,
):
@@ -67,7 +65,7 @@ def main(
    full_cmd = (
        queue_cmd(duration=duration, memory=memory)
        + ["--"]
        + train_cmd(sde, cpm_dataset, workdir, config, ctx.args)
        + train_cmd(sde, workdir, config, ctx.args)
    )
    print(" ".join(full_cmd).strip(), file=sys.stderr)
    output = subprocess.run(full_cmd, capture_output=True)

File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
26 changes: 0 additions & 26 deletions bin/deterministic/local-test-train

This file was deleted.

2 changes: 1 addition & 1 deletion bin/deterministic/main.py
@@ -13,7 +13,7 @@
    "config", None, "Training configuration.", lock_config=True
)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train.")
flags.mark_flags_as_required(["workdir", "config", "mode"])

