work on documenting how to train and sample from a model
henryaddison committed Mar 18, 2024
1 parent fc4ed24 commit 44a3e2c
Showing 3 changed files with 32 additions and 45 deletions.
4 changes: 4 additions & 0 deletions .env.example
@@ -1,2 +1,6 @@
# [required] root path for where to find datasets and store models
DERIVED_DATA=/path/to/derived_data
# [optional] log level
LOG_LEVEL=INFO
# [optional] slack webhook url for training and samples notifications
KK_SLACK_WH_URL=https://hooks.slack.com
68 changes: 27 additions & 41 deletions README.md
@@ -11,68 +11,54 @@ by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.
3. Install unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
4. Configure the necessary environment variables: `DERIVED_DATA` and `KK_SLACK_WH_URL`
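
For illustration only, a minimal sketch of reading these variables in Python (the use of python-dotenv here is an assumption, not necessarily what this repo uses; a plain `source .env`-style setup works just as well):

```python
# a minimal sketch of reading the environment configuration;
# python-dotenv is an assumption, not necessarily what this repo uses
import os

from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv()  # picks up a local .env file if present

derived_data = os.environ["DERIVED_DATA"]        # required
log_level = os.environ.get("LOG_LEVEL", "INFO")  # optional
slack_webhook = os.environ.get("KK_SLACK_WH_URL")  # optional
```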

### Usage
## Usage


#### Smoke test
### Smoke test

`bin/local-test-train`


#### Training
### Training

Train models through `main.py`.
Train models through `bin/main.py`, e.g.

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```

```sh
main.py:
--mode: <train>: Running mode: train
--workdir: Working directory for storing data related to the model such as model snapshots, transforms or samples
--config: Training configuration.
(default: 'None')
--mode: <train|eval>: Running mode: train or eval
--workdir: Working directory
```

* `config` is the path to the config file. Our prescribed config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and should be quite self-explanatory.
* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.

**Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
* dataset: One of `cifar10`, `celeba`, `celebahq`, `celebahq_256`, `ffhq_256`, `celebahq`, `ffhq`.
* model: One of `ncsn`, `ncsnv2`, `ncsnpp`, `ddpm`, `ddpmpp`.
* continuous: train the model with continuously sampled time steps.
* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, transforms and samples. It is recommended to be a subdirectory of `${DERIVED_DATA}`.

* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.
* `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and heavily based on ncsnpp config files.

* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta` .
**Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
* SDE: `subvpsde`
* data source: `ukcp_local`
* variable: `pr`
* ensemble members: `12em` (all 12) or `1em` (single)
* model: `cncsnpp`
* continuous: train the model with continuously sampled time steps.

These functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package. For example, to generate samples and evaluate sample quality, supply the `--config.eval.enable_sampling` flag; to compute log-likelihoods, supply the `--config.eval.enable_bpd` flag, and specify `--config.eval.dataset=train/test` to indicate whether to compute the likelihoods on the training or test dataset.
Datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` is the name of the dataset used to train the model.
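
For orientation, a config in this style is just a Python module exposing a `get_config()` that returns an `ml_collections.ConfigDict`. A minimal sketch, with illustrative keys and placeholder values rather than the emulator's exact schema:

```python
# a minimal sketch of an ml_collections-style config module;
# keys and values are illustrative placeholders, not the emulator's exact schema
import ml_collections


def get_config():
    config = ml_collections.ConfigDict()

    config.training = training = ml_collections.ConfigDict()
    training.sde = "subvpsde"
    training.continuous = True  # continuously sampled time steps

    config.data = data = ml_collections.ConfigDict()
    # name of a dataset directory under ${DERIVED_DATA}/moose/nc-datasets/
    data.dataset_name = "my_dataset"

    config.model = model = ml_collections.ConfigDict()
    model.name = "cncsnpp"

    return config
```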

#### Sampling
TODO
Functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package.

## How to extend the code
* **New SDEs**: inherit the `sde_lib.SDE` abstract class and implement all abstract methods. The `discretize()` method is optional and the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for this new SDE.
* **New predictors**: inherit the `sampling.Predictor` abstract class, implement the `update_fn` abstract method, and register its name with `@register_predictor` (a sketch follows this list). The new predictor can be directly used in `sampling.get_pc_sampler` for Predictor-Corrector sampling, and all other controllable generation methods in `controllable_generation.py`.
* **New correctors**: inherit the `sampling.Corrector` abstract class, implement the `update_fn` abstract method, and register its name with `@register_corrector`. The new corrector can be directly used in `sampling.get_pc_sampler`, and all other controllable generation methods in `controllable_generation.py`.
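
As a rough sketch of the predictor case, modelled on the upstream score_sde sampling API (the import path and exact signatures in this fork are assumptions; check `sampling.py` before relying on them):

```python
# a sketch of a custom predictor, modelled on the upstream score_sde API;
# the import path and signatures are assumptions; check sampling.py in this fork
import numpy as np
import torch

from ml_downscaling_emulator.score_sde_pytorch_hja22 import sampling


@sampling.register_predictor(name="my_euler_maruyama")
class MyEulerMaruyamaPredictor(sampling.Predictor):
    def update_fn(self, x, t):
        # one Euler-Maruyama step of the reverse-time SDE
        dt = -1.0 / self.rsde.N
        z = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * dt
        x = x_mean + diffusion[:, None, None, None] * np.sqrt(-dt) * z
        return x, x_mean
```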

## Tips
* When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via `config.training.n_jitted_steps`. For CIFAR-10, we recommend using `config.training.n_jitted_steps=5` when your GPU/TPU has sufficient memory; otherwise we recommend using `config.training.n_jitted_steps=1`. Our current implementation requires `config.training.log_freq` to be divisible by `n_jitted_steps` for logging and checkpointing to work normally.
* The `snr` (signal-to-noise ratio) parameter of `LangevinCorrector` somewhat behaves like a temperature parameter. Larger `snr` typically results in smoother samples, while smaller `snr` gives more diverse but lower quality samples. Typical values of `snr` are `0.05 - 0.2`, and it requires tuning to strike the sweet spot.
* For VE SDEs, we recommend choosing `config.model.sigma_max` to be the maximum pairwise distance between data samples in the training dataset.
### Sampling

## References
Once you have trained a model, create samples from it with `bin/predict.py`, e.g.

This code is based on the following work:
```bib
@inproceedings{
song2021scorebased,
title={Score-Based Generative Modeling through Stochastic Differential Equations},
author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=PxTIG12RRHS}
}
```

```sh
python bin/predict.py --checkpoint epoch-20 --num-samples 1 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --ensemble-member 01 ${DERIVED_DATA}/path/to/models/paper-12em
```

This work is built upon some previous papers which might also interest you:

* Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." *Proceedings of the 33rd Annual Conference on Neural Information Processing Systems*. 2019.
* Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
* Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
This will use the model checkpoint in `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config saved during training at `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`. It will store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_data}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
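
As a quick sanity check of the output, a generated sample file can be inspected with xarray (assuming xarray and a netCDF backend are installed; the path below is a placeholder):

```python
# a quick look at a generated sample file; the path is a placeholder
import glob

import xarray as xr

sample_files = glob.glob(
    "/path/to/derived_data/path/to/models/paper-12em/samples/**/predictions-*.nc",
    recursive=True,
)
ds = xr.open_dataset(sample_files[0])
print(ds)  # inspect the dimensions, coordinates and sampled variables
```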
5 changes: 1 addition & 4 deletions bin/main.py
@@ -31,10 +31,7 @@
"config", None, "Training configuration.", lock_config=True
)
flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
flags.DEFINE_string(
"eval_folder", "eval", "The folder name for storing evaluation results"
)
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train")
flags.mark_flags_as_required(["workdir", "config", "mode"])


