forked from yang-song/score_sde_pytorch
Commit: Merge pull request #33 from henryaddison/paper ("Prep code for paper publishing")
Showing 133 changed files with 380 additions and 4,022 deletions.
`.env.example` (new file):

@@ -0,0 +1,6 @@
```sh
# [required] root path for where to find datasets and store models
DERIVED_DATA=/path/to/derived_data
# [optional] log level
LOG_LEVEL=INFO
# [optional] slack webhook url for training and samples notifications
KK_SLACK_WH_URL=https://hooks.slack.com
```
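For orientation, a minimal sketch of how an application might read these variables (standard library only; the variable names come from the example file above, everything else is illustrative):

```python
import os

# Required: root path under which datasets are found and models are stored.
derived_data = os.environ["DERIVED_DATA"]  # raises KeyError if unset

# Optional: log level, defaulting to INFO as suggested above.
log_level = os.environ.get("LOG_LEVEL", "INFO")

# Optional: Slack webhook URL; None simply means no notifications.
slack_webhook_url = os.environ.get("KK_SLACK_WH_URL")
```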
@@ -0,0 +1,33 @@ | ||
name: CI | ||
|
||
on: [push] | ||
|
||
jobs: | ||
ci-checks: | ||
runs-on: ubuntu-latest | ||
strategy: | ||
max-parallel: 5 | ||
|
||
steps: | ||
- name: Clone repo | ||
uses: actions/checkout@v4 | ||
- name: setup-micromamba | ||
uses: mamba-org/[email protected] | ||
with: | ||
environment-file: environment.lock.yml | ||
init-shell: bash | ||
cache-environment: true | ||
post-cleanup: 'all' | ||
- name: Install package | ||
run: | | ||
pip install -e . | ||
shell: micromamba-shell {0} | ||
- name: Install unet | ||
uses: actions/checkout@v4 | ||
with: | ||
repository: henryaddison/Pytorch-UNet | ||
path: src/ml_downscaling_emulator/unet | ||
- name: Test with pytest | ||
run: | | ||
pytest | ||
shell: micromamba-shell {0} |
`README.md` (updated):
# ML Downscaling Emulator

A machine learning emulator of a CPM (convection-permitting model) based on a diffusion model.

Diffusion model implementation forked from the PyTorch implementation for the paper [Score-Based Generative Modeling through Stochastic Differential Equations](https://openreview.net/forum?id=PxTIG12RRHS) by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.sohldickstein.com/), [Diederik P. Kingma](http://dpkingma.com/), [Abhishek Kumar](http://users.umiacs.umd.edu/~abhishek/), [Stefano Ermon](https://cs.stanford.edu/~ermon/), and [Ben Poole](https://cs.stanford.edu/~poole/).
## Dependencies

1. Clone the repo and cd into it.
2. Create the conda environment: `conda env create -f environment.lock.yml` (or add the dependencies to your own environment: `conda env install -f environment.txt`).
3. Activate the conda environment (if not already done).
4. Install ml_downscaling_emulator locally: `pip install -e .`
5. Install the unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
6. Configure application behaviour with environment variables. See `.env.example` for the variables that can be set.

Any datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` names the dataset used to train the model; a path-resolution sketch follows.
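For example, a minimal sketch of resolving that layout in Python (the helper name and the example dataset name are hypothetical, not part of the package):

```python
import os
from pathlib import Path

def dataset_dir(dataset_name: str) -> Path:
    # Mirrors the layout described above: ${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/
    return Path(os.environ["DERIVED_DATA"]) / "moose" / "nc-datasets" / dataset_name

# List the NetCDF files of the dataset named by config.data.dataset_name.
for nc_file in sorted(dataset_dir("my-dataset").glob("*.nc")):
    print(nc_file.name)
```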
## Usage

### Smoke test

```sh
tests/smoke-test
```

Uses a simpler network to test the full training and sampling regime. Recommended to run with a sample of the dataset.
### Training

Train models through `bin/main.py`, e.g.

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```

```sh
main.py:
  --config: Training configuration.
  --mode: <train>: Running mode: train
  --workdir: Working directory for storing data related to the model, such as model snapshots, transforms or samples
```
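The config files follow the `ml_collections` convention of exposing a `get_config()` function, so one way to inspect a config without launching training is a sketch like the following (the file path is taken from the example command above; the loading mechanics are illustrative, not a copy of `bin/main.py`):

```python
from importlib import util as importlib_util

CONFIG_PATH = (
    "src/ml_downscaling_emulator/score_sde_pytorch/configs/"
    "subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py"
)

# Load the config module from its file path and build the config object.
spec = importlib_util.spec_from_file_location("train_config", CONFIG_PATH)
config_module = importlib_util.module_from_spec(spec)
spec.loader.exec_module(config_module)
config = config_module.get_config()

print(config.data.dataset_name)  # the training dataset (see Usage above)
```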
* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, transforms and samples. Recommended to be a subdirectory of `${DERIVED_DATA}`.

* `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and are heavily based on the ncsnpp config files.

**Naming conventions of config files**: the path of a config file is a combination of the following dimensions (see the worked example after this list):
* SDE: `subvpsde`
* data source: `ukcp_local`
* variable: `pr`
* ensemble members: `12em` (all 12) or `1em` (a single member)
* model: `cncsnpp`
* continuous: train the model with continuously sampled time steps.
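For example, `subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py`, the config used in the training command above, combines these dimensions: the `subvpsde` SDE, the `ukcp_local` data source, the `pr` variable, all 12 ensemble members (`12em`), the `cncsnpp` model, and continuous time sampling.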
* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming a run after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.

Functionalities can be configured through config files or, more conveniently, through the command-line support of the `ml_collections` package, as sketched below.
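Concretely, that command-line support comes from `ml_collections.config_flags`: registering a `--config` flag also accepts dotted overrides for any field of the loaded config. A minimal sketch of the mechanism (not a copy of `bin/main.py`; the `snapshot_freq` field is assumed from the ncsnpp-style configs):

```python
# sketch_main.py
# Run as e.g.:
#   python sketch_main.py --config=path/to/config.py --config.training.snapshot_freq=1000
from absl import app, flags
from ml_collections.config_flags import config_flags

FLAGS = flags.FLAGS
# Registers --config, which loads the file's get_config() and
# accepts dotted overrides such as --config.training.snapshot_freq=1000.
config_flags.DEFINE_config_file("config")

def main(argv):
    config = FLAGS.config
    print(config.training.snapshot_freq)

if __name__ == "__main__":
    app.run(main)
```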
### Sampling

Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```sh
python bin/predict.py --checkpoint epoch-20 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --ensemble-member 01 --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --num-samples 1 ${DERIVED_DATA}/path/to/models/paper-12em
```

This example command will:
* use the model checkpoint in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config from training, `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`.
* store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_dataset}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc` (see the sketch below for opening them).
* generate samples conditioned on examples from ensemble member `01` in the `test` subset of the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset.
* transform the inputs based on the `bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season` dataset using the `pixelmmsstan` approach.
* generate 1 set of samples.
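The resulting sample files are ordinary NetCDF, so they can be inspected with `xarray` (assumed to be available in the conda environment). A rough sketch, with placeholder paths following the layout described above:

```python
from glob import glob

import xarray as xr  # assumed to be provided by the conda environment

# Placeholder workdir and dataset names matching the example command above.
sample_dir = (
    "/path/to/derived_data/path/to/models/paper-12em/samples/"
    "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season/"
    "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-pixelmmsstan/"
    "test/01"
)
prediction_files = sorted(glob(f"{sample_dir}/predictions-*.nc"))

# Open the first set of samples and list its variables and coordinates.
samples = xr.open_dataset(prediction_files[0])
print(samples)
```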
## How to extend the code
* **New SDEs**: inherit the `sde_lib.SDE` abstract class and implement all abstract methods. The `discretize()` method is optional; the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for the new SDE.
* **New predictors**: inherit the `sampling.Predictor` abstract class, implement the `update_fn` abstract method, and register its name with `@register_predictor`. The new predictor can be used directly in `sampling.get_pc_sampler` for Predictor-Corrector sampling, and in all other controllable generation methods in `controllable_generation.py`. A sketch follows this list.
* **New correctors**: inherit the `sampling.Corrector` abstract class, implement the `update_fn` abstract method, and register its name with `@register_corrector`. The new corrector can be used directly in `sampling.get_pc_sampler`, and in all other controllable generation methods in `controllable_generation.py`.
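For instance, a minimal predictor sketch. The import path is assumed from this repository's layout and the registration name is hypothetical; the update step simply re-implements the default Euler-Maruyama discretization from the upstream score_sde code:

```python
import numpy as np
import torch

# Assumed import path for the vendored score_sde code; adjust to the actual package layout.
from ml_downscaling_emulator.score_sde_pytorch import sampling

@sampling.register_predictor(name="my_euler_maruyama")
class MyEulerMaruyamaPredictor(sampling.Predictor):
    def __init__(self, sde, score_fn, probability_flow=False):
        # The base class builds the reverse-time SDE as self.rsde.
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        # One Euler-Maruyama step of the reverse SDE
        # (the [:, None, None, None] broadcast assumes image-shaped batches, as upstream does).
        dt = -1.0 / self.rsde.N
        z = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * dt
        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
        return x, x_mean
```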
## Tips
* When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via `config.training.n_jitted_steps`. For CIFAR-10, we recommend using `config.training.n_jitted_steps=5` when your GPU/TPU has sufficient memory; otherwise we recommend using `config.training.n_jitted_steps=1`. Our current implementation requires `config.training.log_freq` to be divisible by `n_jitted_steps` for logging and checkpointing to work normally.
* The `snr` (signal-to-noise ratio) parameter of `LangevinCorrector` behaves somewhat like a temperature parameter. Larger `snr` typically results in smoother samples, while smaller `snr` gives more diverse but lower-quality samples. Typical values of `snr` are 0.05 to 0.2, and it requires tuning to strike the sweet spot.
* For VE SDEs, we recommend choosing `config.model.sigma_max` to be the maximum pairwise distance between data samples in the training dataset; a back-of-the-envelope sketch follows.
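A sketch of estimating that value (illustrative only; the random array stands in for your flattened training data):

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for the real training set: N flattened samples (replace with your data).
train_samples = rng.normal(size=(256, 64 * 64))

# Maximum pairwise Euclidean distance via the Gram matrix (O(N^2) memory,
# so subsample a large dataset first).
sq_norms = (train_samples ** 2).sum(axis=1)
d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * (train_samples @ train_samples.T)
sigma_max = float(np.sqrt(np.clip(d2, 0.0, None)).max())
print(sigma_max)  # candidate value for config.model.sigma_max
```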
## References

This code is based on the following work:
```bib
@inproceedings{song2021scorebased,
  title={Score-Based Generative Modeling through Stochastic Differential Equations},
  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=PxTIG12RRHS}
}
```

This work is built upon some previous papers which might also interest you:

* Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." *Proceedings of the 33rd Annual Conference on Neural Information Processing Systems*. 2019.
* Song, Yang, and Stefano Ermon. "Improved Techniques for Training Score-Based Generative Models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
* Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising Diffusion Probabilistic Models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.