An implementation of Elucidating the Design Space of Diffusion-Based Generative Models (Karras et al., 2022) for PyTorch. The patching method in Improving Diffusion Model Efficiency Through Patching is implemented as well.
To train models:
$ ./train.py --config CONFIG_FILE --name RUN_NAME
For instance, to train a model on MNIST:
$ ./train.py --config configs/config_mnist.json --name RUN_NAME
The configuration file allows you to specify the dataset type. Currently supported types are "imagefolder" (finds all images in that folder and its subfolders, recursively), "cifar10" (CIFAR-10), "mnist" (MNIST), and "huggingface" (Hugging Face Datasets).
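For reference, below is a minimal sketch of the dataset portion of a configuration file. The key names ("dataset", "type", "location") are assumptions based on the bundled example configs; check configs/config_mnist.json and the other provided configs for the exact schema.

{
    "dataset": {
        "type": "imagefolder",
        "location": "/path/to/images"
    }
}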
Multi-GPU and multi-node training is supported with Hugging Face Accelerate. You can configure Accelerate by running:
$ accelerate config
on all nodes, then running:
$ accelerate launch train.py --config CONFIG_FILE --name RUN_NAME
on all nodes.
- k-diffusion supports an experimental model output type, an isotropic Gaussian, which seems to have a lower gradient noise scale and to train faster than Karras et al. (2022) diffusion models.
- k-diffusion has wrappers for v-diffusion-pytorch, OpenAI diffusion, and CompVis diffusion models, allowing them to be used with its samplers and ODE/SDE (see the sketch after this list).
- k-diffusion models support progressive growing.
- k-diffusion implements a sampler inspired by DPM-Solver and Karras et al. (2022) Algorithm 2 that produces higher quality samples at the same number of function evaluations as Karras Algorithm 2. It also implements a linear multistep sampler (comparable to PLMS).
- k-diffusion supports CLIP guided sampling from unconditional diffusion models (see sample_clip_guided.py).
- k-diffusion supports log likelihood calculation (not a variational lower bound) for native models and all wrapped models.
- k-diffusion can calculate, during training, the FID and KID vs the training set.
- k-diffusion can calculate, during training, the gradient noise scale (1 / SNR) from An Empirical Model of Large-Batch Training (https://arxiv.org/abs/1812.06162).
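As a rough illustration of the wrappers and samplers (not a canonical recipe), the sketch below wraps a v-objective model with k_diffusion.external.VDenoiser and samples from it with k_diffusion.sampling.sample_dpm_2 on a Karras et al. (2022) noise schedule. The dummy model is a stand-in for a real pretrained network, and the wrapper and sampler names are assumed from k_diffusion.external and k_diffusion.sampling; check them against the version you have installed.

import torch
from torch import nn
import k_diffusion as K

# Stand-in for a real v-objective denoiser: it just predicts zeros. Replace it
# with a pretrained network that takes (x, t) and returns a v prediction.
class DummyVModel(nn.Module):
    def forward(self, x, t):
        return torch.zeros_like(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
inner_model = DummyVModel().to(device)

# Wrap the model so k-diffusion's samplers can call it; OpenAIDenoiser and
# CompVisDenoiser in k_diffusion.external wrap the other supported model types.
model = K.external.VDenoiser(inner_model)

# Build a Karras et al. (2022) noise schedule, then draw initial noise at the
# largest sigma.
sigmas = K.sampling.get_sigmas_karras(50, sigma_min=1e-2, sigma_max=80., device=device)
x = torch.randn([4, 3, 64, 64], device=device) * sigmas[0]

# The DPM-Solver-inspired sampler; sample_lms() is the linear multistep sampler
# mentioned in the list above.
samples = K.sampling.sample_dpm_2(model, x, sigmas)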
To do:
- Anything except unconditional image diffusion models
- Latent diffusion