forked from yang-song/score_sde_pytorch
Commit 62946f2 (1 parent: 89f92f3): work on documenting how to train and sample from a model

Showing 3 changed files with 32 additions and 45 deletions.
@@ -1,2 +1,6 @@
# [required] root path for where to find datasets and store models
DERIVED_DATA=/path/to/derived_data
# [optional] log level
LOG_LEVEL=INFO
# [optional] slack webhook url for training and samples notifications
KK_SLACK_WH_URL=https://hooks.slack.com
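
These are ordinary environment variables; below is a minimal sketch, assuming nothing beyond the Python standard library, of how a training or sampling script might read them (the fallback behaviour shown is illustrative, not taken from this repository):

```python
import logging
import os

# Required: root path under which datasets are found and model artifacts are stored.
derived_data = os.environ["DERIVED_DATA"]

# Optional: log level, falling back to INFO here purely for illustration.
logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))

# Optional: Slack webhook for training/sampling notifications; skip notifications if unset.
slack_webhook_url = os.environ.get("KK_SLACK_WH_URL")
if not slack_webhook_url:
    logging.getLogger(__name__).info("KK_SLACK_WH_URL not set; notifications disabled")
```
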
@@ -11,68 +11,54 @@ by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.
3. Install unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
4. Configure necessary environment variables: `DERIVED_DATA` and `KK_SLACK_WH_URL`

## Usage

### Smoke test

`bin/local-test-train`

### Training

Train models through `bin/main.py`, e.g.

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```

```sh
main.py:
  --mode: <train>: Running mode: train
  --workdir: Working directory for storing data related to the model, such as model snapshots, transforms or samples
  --config: Training configuration.
```
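
For orientation, here is a minimal sketch of how such a flag interface is typically wired up with `absl` and `ml_collections` config flags. It is an assumption about the general shape, not the actual contents of `bin/main.py`:

```python
# Hypothetical sketch of a main.py-style entry point; not the repository's actual code.
from absl import app, flags
from ml_collections import config_flags

FLAGS = flags.FLAGS

config_flags.DEFINE_config_file("config", None, "Training configuration.")
flags.DEFINE_string("workdir", None, "Working directory for snapshots, transforms and samples.")
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train")
flags.mark_flags_as_required(["config", "workdir", "mode"])


def main(argv):
    config = FLAGS.config  # an ml_collections.ConfigDict built from the --config file
    if FLAGS.mode == "train":
        # The real script would construct the model from `config` and run the
        # training loop, writing checkpoints and samples under FLAGS.workdir.
        print(f"training in {FLAGS.workdir} with config:\n{config}")


if __name__ == "__main__":
    app.run(main)
```

A convenient side effect of `config_flags` is that individual config fields can also be overridden on the command line in dotted form, e.g. `--config.training.batch_size=2` (the field name here is only an example).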

* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, transforms and samples. Recommended to be a subdirectory of `${DERIVED_DATA}`.
* `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and heavily based on ncsnpp config files (see the sketch after the naming conventions below).
* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming a run after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.

**Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
* SDE: `subvpsde`
* data source: `ukcp_local`
* variable: `pr`
* ensemble members: `12em` (all 12) or `1em` (single)
* model: `cncsnpp`
* continuous: train the model with continuously sampled time steps.
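
For reference, `ml_collections` config files are plain Python modules that expose a `get_config()` function returning a `ConfigDict`. The sketch below shows that general shape only; the field names are illustrative assumptions rather than this repository's actual schema:

```python
# Illustrative ml_collections-style config file; field names are assumptions for the sketch.
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.training = training = ml_collections.ConfigDict()
    training.sde = "subvpsde"
    training.continuous = True  # train with continuously sampled time steps

    config.data = data = ml_collections.ConfigDict()
    data.dataset_name = "my_dataset"  # see the dataset location note below
    data.ensemble_members = 12

    config.model = model = ml_collections.ConfigDict()
    model.name = "cncsnpp"

    return config
```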

Any datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` is the name of the dataset used to train the model.
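
A minimal sketch of how that layout resolves to a concrete directory; the helper function is hypothetical, and only the `DERIVED_DATA` variable and the `moose/nc-datasets/{dataset_name}` layout come from the text above:

```python
import os
from pathlib import Path


def dataset_dir(dataset_name: str) -> Path:
    """Hypothetical helper: locate a named dataset under ${DERIVED_DATA}."""
    return Path(os.environ["DERIVED_DATA"]) / "moose" / "nc-datasets" / dataset_name


# Dataset name taken from the sampling example further below.
print(dataset_dir("bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"))
```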

Functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package.

### Sampling

Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

```sh
python bin/predict.py --checkpoint epoch-20 --num-samples 1 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --ensemble-member 01 ${DERIVED_DATA}/path/to/models/paper-12em
```

This will use the model checkpoint in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config saved at training time in `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`. Generated samples are stored in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_dataset}-{input_transform_key}/{split}/{ensemble_member}/`, in files named like `predictions-{uuid}.nc`.
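
The generated files are netCDF datasets; a quick way to inspect one is with `xarray`, assuming it is installed in the environment (an assumption, it is not part of the setup steps above):

```python
# Illustrative only: open one generated sample file and print its structure.
import glob
import os

import xarray as xr

# Placeholder path matching the example command above.
samples_root = os.path.join(os.environ["DERIVED_DATA"], "path/to/models/paper-12em/samples")
sample_files = glob.glob(os.path.join(samples_root, "**", "predictions-*.nc"), recursive=True)

ds = xr.open_dataset(sample_files[0])
print(ds)  # dimensions, coordinates and data variables of the sampled predictions
```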