Added DiT examples
yuanchenyang committed Apr 24, 2024
1 parent 57cd336 commit 8ba1ddc
Showing 4 changed files with 69 additions and 22 deletions.
53 changes: 31 additions & 22 deletions README.md
@@ -37,27 +37,34 @@ losses = [ns.loss.item() for ns in trainer]
Results on various toy datasets:

<p align="center">
-<img src="https://github.com/yuanchenyang/smalldiffusion/blob/main/imgs/toy_models.png" width=100%>
+<img src="https://raw.githubusercontent.com/yuanchenyang/smalldiffusion/main/imgs/toy_models.png" width=100%>
</p>

-### U-Net models
-The same code can be used to train [U-Net-based models][unet-py]. To
-train a model on the FashionMNIST dataset and generate a batch of samples (after
-first running `accelerate config`):
+### Diffusion Transformer
+We provide [a concise implementation][model-code] of the diffusion transformer introduced in
+[[Peebles and Xie 2022]][dit-paper]. To train a model on the FashionMNIST dataset and
+generate a batch of samples (after first running `accelerate config`):

```
-accelerate launch examples/fashion_mnist.py
+accelerate launch examples/fashion_mnist_dit.py
```

With the provided default parameters and training on a single GPU for around 2
hours, the model can achieve a [FID
score](https://paperswithcode.com/sota/image-generation-on-fashion-mnist) of
-around 12-13, producing the following generated outputs:
+around 5-6, producing the following generated outputs:

<p align="center">
-<img src="https://github.com/yuanchenyang/smalldiffusion/blob/main/imgs/fashion-mnist-samples.png" width=50%>
+<img src="https://raw.githubusercontent.com/yuanchenyang/smalldiffusion/main/imgs/fashion-mnist-samples.png" width=50%>
</p>

+### U-Net models
+The same code can be used to train [U-Net-based models][unet-py].
+
+```
+accelerate launch examples/fashion_mnist_unet.py
+```
+
### StableDiffusion
smalldiffusion's sampler works with any pretrained diffusion model, and supports
DDPM, DDIM as well as accelerated sampling algorithms. In
@@ -82,7 +89,7 @@ schedules, as demonstrated in [examples/stablediffusion.py][stablediffusion]. A
few examples on tweaking the parameter `gam`:

<p align="center">
-<img src="https://github.com/yuanchenyang/smalldiffusion/blob/main/imgs/sd_examples.jpg" width=100%>
+<img src="https://raw.githubusercontent.com/yuanchenyang/smalldiffusion/main/imgs/sd_examples.jpg" width=100%>
</p>


@@ -141,7 +148,7 @@ Three schedules are provided:

The following plot shows these three schedules with default parameters.
<p align="center">
-<img src="https://github.com/yuanchenyang/smalldiffusion/blob/main/imgs/schedule.png" width=40%>
+<img src="https://raw.githubusercontent.com/yuanchenyang/smalldiffusion/main/imgs/schedule.png" width=40%>
</p>

### Training
@@ -174,15 +181,17 @@ and implemented in only 5 lines of code, see Appendix A of [[Permenter and
Yuan]][arxiv-url].


-[diffusion-py]: https://github.com/yuanchenyang/smalldiffusion/blob/main/src/smalldiffusion/diffusion.py
-[unet-py]: https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/unet.py
-[diffusers-wrapper]: https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/diffusers_wrapper.py
-[stablediffusion]: https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/stablediffusion.py
-[build-img]: https://github.com/yuanchenyang/smalldiffusion/workflows/CI/badge.svg
-[build-url]: https://github.com/yuanchenyang/smalldiffusion/actions?query=workflow%3ACI
-[pypi-img]: https://img.shields.io/badge/pypi-blue
-[pypi-url]: https://pypi.org/project/smalldiffusion/
-[blog-img]: https://img.shields.io/badge/Tutorial-blogpost-blue
-[blog-url]: https://www.chenyang.co/diffusion.html
-[arxiv-img]: https://img.shields.io/badge/Paper-arxiv-blue
-[arxiv-url]: https://arxiv.org/abs/2306.04848
+[diffusion-py]:https://github.com/yuanchenyang/smalldiffusion/blob/main/src/smalldiffusion/diffusion.py
+[unet-py]:https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/unet.py
+[diffusers-wrapper]:https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/diffusers_wrapper.py
+[stablediffusion]:https://github.com/yuanchenyang/smalldiffusion/blob/main/examples/stablediffusion.py
+[build-img]:https://github.com/yuanchenyang/smalldiffusion/workflows/CI/badge.svg
+[build-url]:https://github.com/yuanchenyang/smalldiffusion/actions?query=workflow%3ACI
+[pypi-img]:https://img.shields.io/badge/pypi-blue
+[pypi-url]:https://pypi.org/project/smalldiffusion/
+[dit-paper]:https://arxiv.org/abs/2212.09748
+[model-code]:https://github.com/yuanchenyang/smalldiffusion/blob/main/src/smalldiffusion/model.py
+[blog-img]:https://img.shields.io/badge/Tutorial-blogpost-blue
+[blog-url]:https://www.chenyang.co/diffusion.html
+[arxiv-img]:https://img.shields.io/badge/Paper-arxiv-blue
+[arxiv-url]:https://arxiv.org/abs/2306.04848
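
Every image edit in the README diff above swaps a `github.com/<owner>/<repo>/blob/<ref>/<path>` link for its `raw.githubusercontent.com` equivalent. A `blob` URL points at GitHub's HTML file viewer, whereas the raw host serves the file itself. The rewrite can be sketched as follows; the helper name `blob_to_raw` is hypothetical and is not part of the repository.

```python
from urllib.parse import urlsplit, urlunsplit


def blob_to_raw(url: str) -> str:
    """Rewrite a github.com 'blob' URL to its raw.githubusercontent.com form.

    Hypothetical helper for illustration only. URLs that do not match the
    github.com/<owner>/<repo>/blob/<ref>/<path> shape are returned unchanged.
    """
    parts = urlsplit(url)
    segments = parts.path.strip("/").split("/")
    # Expect: owner / repo / "blob" / ref / at least one path segment.
    if parts.netloc != "github.com" or len(segments) < 5 or segments[2] != "blob":
        return url
    owner, repo, _, *rest = segments  # drop the "blob" segment
    raw_path = "/" + "/".join([owner, repo, *rest])
    return urlunsplit(
        (parts.scheme, "raw.githubusercontent.com", raw_path, parts.query, parts.fragment)
    )


old = "https://github.com/yuanchenyang/smalldiffusion/blob/main/imgs/toy_models.png"
print(blob_to_raw(old))
```

Applied to the first image link in the diff, this reproduces the replacement URL that the commit introduces.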
38 changes: 38 additions & 0 deletions examples/fashion_mnist_dit.py
@@ -0,0 +1,38 @@
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from torchvision import transforms as tf
from torchvision.datasets import FashionMNIST
from torchvision.utils import make_grid, save_image
from torch_ema import ExponentialMovingAverage as EMA
from tqdm import tqdm

from smalldiffusion import ScheduleDDPM, samples, training_loop, MappedDataset, DiT

# Setup
accelerator = Accelerator()
dataset = MappedDataset(FashionMNIST('datasets', train=True, download=True,
                                     transform=tf.Compose([
                                         tf.RandomHorizontalFlip(),
                                         tf.ToTensor(),
                                         tf.Lambda(lambda t: (t * 2) - 1)
                                     ])),
                        lambda x: x[0])
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
schedule = ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)
model = DiT(in_dim=28, channels=1,
            patch_size=2, depth=6, head_dim=32, num_heads=6, mlp_ratio=4.0)

# Train
trainer = training_loop(loader, model, schedule, epochs=300, lr=1e-3, accelerator=accelerator)
ema = EMA(model.parameters(), decay=0.99)
ema.to(accelerator.device)
for ns in trainer:
    ns.pbar.set_description(f'Loss={ns.loss.item():.5}')
    ema.update()

# Sample
with ema.average_parameters():
    *xt, x0 = samples(model, schedule.sample_sigmas(20), gam=1.6, batchsize=64, accelerator=accelerator)
    save_image(((make_grid(x0) + 1)/2).clamp(0, 1), 'fashion_mnist_samples.png')
    torch.save(model.state_dict(), 'checkpoint.pth')
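
The script above configures `ScheduleDDPM(beta_start=0.0001, beta_end=0.02, N=1000)`. As a rough guide to what those numbers imply, the sketch below computes noise levels from a linear beta schedule using the common DDPM relation `sigma_t = sqrt(1 / alpha_bar_t - 1)`, where `alpha_bar_t` is the running product of `1 - beta_s`. That smalldiffusion parameterizes its schedule in exactly this way is an assumption here, and `ddpm_sigmas` is a hypothetical name rather than the library's API.

```python
import math


def ddpm_sigmas(beta_start: float, beta_end: float, n: int) -> list[float]:
    """Noise levels for a linear beta schedule (illustrative, not the library's code).

    Assumes the convention sigma_t = sqrt(1 / alpha_bar_t - 1), where
    alpha_bar_t is the running product of (1 - beta_s). Requires n >= 2.
    """
    sigmas = []
    alpha_bar = 1.0
    for t in range(n):
        # Linearly interpolate beta from beta_start to beta_end (inclusive).
        beta = beta_start + (beta_end - beta_start) * t / (n - 1)
        alpha_bar *= 1.0 - beta
        sigmas.append(math.sqrt(1.0 / alpha_bar - 1.0))
    return sigmas


sigmas = ddpm_sigmas(beta_start=0.0001, beta_end=0.02, n=1000)
print(f"smallest sigma: {sigmas[0]:.4f}, largest sigma: {sigmas[-1]:.1f}")
```

Under this assumption the noise level grows monotonically with the step index, so a sampler that uses only 20 steps, as in `schedule.sample_sigmas(20)`, has to pick a subset of these 1000 levels.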
File renamed without changes.
Binary file modified imgs/fashion-mnist-samples.png