From 44a3e2c427813fa22320350a26f051006c9c75b0 Mon Sep 17 00:00:00 2001
From: Henry Addison
Date: Mon, 18 Mar 2024 15:01:25 +0000
Subject: [PATCH] work on documenting how to train and sample from a model

---
 .env.example |  4 ++++
 README.md    | 68 +++++++++++++++++++++------------------------------
 bin/main.py  |  5 +---
 3 files changed, 32 insertions(+), 45 deletions(-)

diff --git a/.env.example b/.env.example
index 599de1a8b..04a999e83 100644
--- a/.env.example
+++ b/.env.example
@@ -1,2 +1,6 @@
+# [required] root path under which to find datasets and store models
 DERIVED_DATA=/path/to/derived_data
+# [optional] log level
 LOG_LEVEL=INFO
+# [optional] Slack webhook URL for training and sampling notifications
+KK_SLACK_WH_URL=https://hooks.slack.com
diff --git a/README.md b/README.md
index 05595c790..77fd1361a 100644
--- a/README.md
+++ b/README.md
@@ -11,68 +11,54 @@ by [Yang Song](https://yang-song.github.io), [Jascha Sohl-Dickstein](http://www.
 3. Install unet code: `git clone --depth 1 git@github.com:henryaddison/Pytorch-UNet src/ml_downscaling_emulator/unet`
 4. Configure necessary environment variables: `DERIVED_DATA` and `KK_SLACK_WH_URL`
 
-### Usage
+## Usage
 
-
-#### Smoke test
+### Smoke test
 
 `bin/local-test-train`
 
-#### Training
+### Training
 
-Train models through `main.py`.
+Train models through `bin/main.py`, e.g.
+
+```sh
+python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
+```
 
 ```sh
 main.py:
+  --mode: <train>: Running mode: train
+  --workdir: Working directory for storing data related to the model, such as model snapshots, transforms or samples
   --config: Training configuration.
     (default: 'None')
-  --mode: <train|eval>: Running mode: train or eval
-  --workdir: Working directory
 ```
 
-* `config` is the path to the config file. Our prescribed config files are provided in `configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and should be quite self-explanatory.
+* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta`.
 
-  **Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
-  * dataset: One of `cifar10`, `celeba`, `celebahq`, `celebahq_256`, `ffhq_256`, `celebahq`, `ffhq`.
-  * model: One of `ncsn`, `ncsnv2`, `ncsnpp`, `ddpm`, `ddpmpp`.
-  * continuous: train the model with continuously sampled time steps.
+* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, transforms and samples. It is recommended to be a subdirectory of `${DERIVED_DATA}`.
 
-* `workdir` is the path that stores all artifacts of one experiment, like checkpoints, samples, and evaluation results.
+* `config` is the path to the config file. Config files for emulators are provided in `src/configs/`. They are formatted according to [`ml_collections`](https://github.com/google/ml_collections) and are heavily based on the ncsnpp config files; a sketch of their shape follows the naming conventions below.
 
-* `mode` is "train". When set to "train", it starts the training of a new model, or resumes the training of an old model if its meta-checkpoints (for resuming running after pre-emption in a cloud environment) exist in `workdir/checkpoints-meta` .
+
+  **Naming conventions of config files**: the path of a config file is a combination of the following dimensions:
+  * SDE: `subvpsde`
+  * data source: `ukcp_local`
+  * variable: `pr`
+  * ensemble members: `12em` (all 12) or `1em` (single)
+  * model: `cncsnpp`
+  * continuous: train the model with continuously sampled time steps.
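+
+  A hypothetical sketch of the shape of such a config file (`data.dataset_name` is referenced below; the other keys and values are illustrative, not taken verbatim from the repo):
+
+  ```python
+  import ml_collections
+
+
+  def get_config():
+      config = ml_collections.ConfigDict()
+
+      # training options, e.g. whether to use continuously sampled time steps
+      config.training = training = ml_collections.ConfigDict()
+      training.continuous = True
+
+      # name of the dataset used for training (see dataset location below)
+      config.data = data = ml_collections.ConfigDict()
+      data.dataset_name = "bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season"
+
+      # model architecture
+      config.model = model = ml_collections.ConfigDict()
+      model.name = "cncsnpp"
+
+      return config
+  ```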
 
-  These functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package. For example, to generate samples and evaluate sample quality, supply the `--config.eval.enable_sampling` flag; to compute log-likelihoods, supply the `--config.eval.enable_bpd` flag, and specify `--config.eval.dataset=train/test` to indicate whether to compute the likelihoods on the training or test dataset.
+Datasets are assumed to be found in `${DERIVED_DATA}/moose/nc-datasets/{dataset_name}/`. In particular, the config key `config.data.dataset_name` names the dataset used to train the model.
 
-#### Sampling
-TODO
+Functionalities can be configured through config files, or more conveniently, through the command-line support of the `ml_collections` package, as in the example below.
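+
+For example, a single config value can be overridden on the command line (a hypothetical override, assuming the chosen config defines a `training.batch_size` key):
+
+```sh
+python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch_hja22/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train --config.training.batch_size=16
+```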
 
-## How to extend the code
-* **New SDEs**: inherent the `sde_lib.SDE` abstract class and implement all abstract methods. The `discretize()` method is optional and the default is Euler-Maruyama discretization. Existing sampling methods and likelihood computation will automatically work for this new SDE.
-* **New predictors**: inherent the `sampling.Predictor` abstract class, implement the `update_fn` abstract method, and register its name with `@register_predictor`. The new predictor can be directly used in `sampling.get_pc_sampler` for Predictor-Corrector sampling, and all other controllable generation methods in `controllable_generation.py`.
-* **New correctors**: inherent the `sampling.Corrector` abstract class, implement the `update_fn` abstract method, and register its name with `@register_corrector`. The new corrector can be directly used in `sampling.get_pc_sampler`, and all other controllable generation methods in `controllable_generation.py`.
-
-## Tips
-* When using the JAX codebase, you can jit multiple training steps together to improve training speed at the cost of more memory usage. This can be set via `config.training.n_jitted_steps`. For CIFAR-10, we recommend using `config.training.n_jitted_steps=5` when your GPU/TPU has sufficient memory; otherwise we recommend using `config.training.n_jitted_steps=1`. Our current implementation requires `config.training.log_freq` to be dividable by `n_jitted_steps` for logging and checkpointing to work normally.
-* The `snr` (signal-to-noise ratio) parameter of `LangevinCorrector` somewhat behaves like a temperature parameter. Larger `snr` typically results in smoother samples, while smaller `snr` gives more diverse but lower quality samples. Typical values of `snr` is `0.05 - 0.2`, and it requires tuning to strike the sweet spot.
-* For VE SDEs, we recommend choosing `config.model.sigma_max` to be the maximum pairwise distance between data samples in the training dataset.
+### Sampling
 
-## References
+Once you have trained a model, create samples from it with `bin/predict.py`, e.g.
 
-This code based on the following work:
-```bib
-@inproceedings{
-  song2021scorebased,
-  title={Score-Based Generative Modeling through Stochastic Differential Equations},
-  author={Yang Song and Jascha Sohl-Dickstein and Diederik P Kingma and Abhishek Kumar and Stefano Ermon and Ben Poole},
-  booktitle={International Conference on Learning Representations},
-  year={2021},
-  url={https://openreview.net/forum?id=PxTIG12RRHS}
-}
+```sh
+python bin/predict.py --checkpoint epoch-20 --num-samples 1 --dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --split test --input-transform-dataset bham_60km-4x_12em_psl-sphum4th-temp4th-vort4th_eqvt_random-season --input-transform-key pixelmmsstan --ensemble-member 01 ${DERIVED_DATA}/path/to/models/paper-12em
 ```
 
-This work is built upon some previous papers which might also interest you:
-
-* Song, Yang, and Stefano Ermon. "Generative Modeling by Estimating Gradients of the Data Distribution." *Proceedings of the 33rd Annual Conference on Neural Information Processing Systems*. 2019.
-* Song, Yang, and Stefano Ermon. "Improved techniques for training score-based generative models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
-* Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." *Proceedings of the 34th Annual Conference on Neural Information Processing Systems*. 2020.
+This will use the model checkpoint at `${DERIVED_DATA}/path/to/models/paper-12em/checkpoints/{checkpoint}.pth` and the model config saved during training at `${DERIVED_DATA}/path/to/models/paper-12em/config.yml`. It will store the generated samples in `${DERIVED_DATA}/path/to/models/paper-12em/samples/{dataset}/{input_transform_dataset}-{input_transform_key}/{split}/{ensemble_member}/`. Sample files are named like `predictions-{uuid}.nc`.
diff --git a/bin/main.py b/bin/main.py
index b7324173a..a08d2b801 100644
--- a/bin/main.py
+++ b/bin/main.py
@@ -31,10 +31,7 @@
     "config", None, "Training configuration.", lock_config=True
 )
 flags.DEFINE_string("workdir", None, "Work directory.")
-flags.DEFINE_enum("mode", None, ["train", "eval"], "Running mode: train or eval")
-flags.DEFINE_string(
-    "eval_folder", "eval", "The folder name for storing evaluation results"
-)
+flags.DEFINE_enum("mode", None, ["train"], "Running mode: train")
 
 flags.mark_flags_as_required(["workdir", "config", "mode"])
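
Note (not part of the patch): a minimal standalone sketch of the absl-flags pattern that `bin/main.py` uses after this change. The script name and the body of `main` are illustrative; the real script wires these flags into the training entrypoint.

```python
# sketch.py -- illustrative only, not part of this patch
from absl import app, flags

FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Work directory.")
flags.DEFINE_enum("mode", None, ["train"], "Running mode: train")
# Require both flags; absl exits with a usage error if either is missing
# or if --mode is given a value outside ["train"].
flags.mark_flags_as_required(["workdir", "mode"])


def main(argv):
    # By the time main runs, the flags have been parsed and validated.
    print(f"mode={FLAGS.mode} workdir={FLAGS.workdir}")


if __name__ == "__main__":
    app.run(main)
```

Invoked as, e.g., `python sketch.py --mode train --workdir /tmp/run`.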