Post-training experiments on diffusion models using policy gradient methods. The project explores ideas with pretrained, unconditional DDPM models such as google/ddpm-celebahq-256 and google/ddpm-church-256, which are based on Denoising Diffusion Probabilistic Models (Ho et al., 2020).
This project supports the three main downstream tasks used to align image generation in the Training Diffusion Models with Reinforcement Learning (Black et al., 2023) work:
- Aesthetic Quality
- JPEG Compressibility
- JPEG Incompressibility
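As a reference for how such rewards can be computed, here is a minimal, illustrative sketch of a JPEG compressibility reward (the negated size of the JPEG encoding); this is not the repo's exact implementation, and the incompressibility reward simply flips the sign:

```python
import io
from typing import Callable, List

import numpy as np
import torch
from PIL import Image


def jpeg_compressibility(quality: int = 95) -> Callable[[List[np.ndarray]], torch.Tensor]:
    """Reward images that compress well: negative JPEG file size in kilobytes."""

    def _fn(images: List[np.ndarray]) -> torch.Tensor:
        sizes = []
        for img in images:  # each image: uint8 array of shape (H, W, 3)
            buffer = io.BytesIO()
            Image.fromarray(img).save(buffer, format="JPEG", quality=quality)
            sizes.append(buffer.tell() / 1024.0)  # encoded size in kB
        # Smaller files -> higher reward; negate again for the incompressibility task.
        return -torch.tensor(sizes, dtype=torch.float32)

    return _fn
```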
Below is a visual comparison between DDPM samples (i.e., pretrained) and DDPO samples generated from the same initial noise. The DDPO samples use checkpoints finetuned with their corresponding reward functions.
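Such a comparison can be reproduced by sampling both pipelines with the same generator seed, so they start from the same initial noise; a minimal sketch using diffusers (the finetuned checkpoint id is a placeholder, see the table below):

```python
import torch
from diffusers import DDPMPipeline

device = "cuda"
seed = 650  # same seed -> same initial noise for both pipelines

pretrained = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256").to(device)
# Placeholder id: substitute one of the finetuned checkpoints listed in the table below.
finetuned = DDPMPipeline.from_pretrained("over50-celebahq-256").to(device)

ddpm_images = pretrained(batch_size=4, generator=torch.Generator(device=device).manual_seed(seed)).images
ddpo_images = finetuned(batch_size=4, generator=torch.Generator(device=device).manual_seed(seed)).images
```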
- Roughly 6% of the samples generated by the google/ddpm-celebahq-256 model are 50 years old or older (≥ 50).
- Can we use RL to push the model toward generating more samples of this kind?
- Goal: increase the frequency of generated celebrity-like faces over 50 years old.
- We use the ViT Age Classifier (Nate Raw, 2021), trained on the FairFace dataset, to predict the age of the samples; its age classes at or above 50 define the OVER50 reward shown below.
```python
from typing import Any, Callable

import torch

# Helper from this repo for turning sampled tensors into numpy images
# (adjust the import path if it differs in your checkout).
from ddpo.utils import decode_tensor_to_np_img


def over50_old(
    device: str = "cuda",
) -> Callable[[Any], torch.Tensor]:
    """Build a reward function that scores images by the age classifier's ≥50 predictions."""
    from transformers import ViTForImageClassification, ViTImageProcessor

    model = ViTForImageClassification.from_pretrained("nateraw/vit-age-classifier")
    transforms = ViTImageProcessor.from_pretrained("nateraw/vit-age-classifier")
    model.to(device)
    model.eval()

    def _fn(images):
        # Convert the sampled tensors to numpy images and preprocess them for the ViT classifier.
        inputs = transforms(
            decode_tensor_to_np_img(
                images,
                melt_batch=False,
            ),
            return_tensors="pt",
        ).pixel_values.to(device)
        with torch.no_grad():
            outputs = model(inputs).logits
        # Columns 6, 7, 8 are the "50-59", "60-69", and "more than 70" age buckets;
        # their summed logits form the OVER50 reward.
        return outputs[:, 6:].sum(dim=1)

    return _fn
```
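Note that the reward is the sum of the classifier's logits over the three age buckets at or above 50 ("50-59", "60-69", "more than 70"); summing the softmax probabilities over the same columns would be an alternative if a bounded reward in [0, 1] is preferred.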
Figures: DDPM age distribution based on faces, and the post-training distribution with OVER50.
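Such age distributions can be estimated by classifying a batch of generated samples and counting the predicted age buckets; a rough sketch (sample counts are arbitrary and purely illustrative):

```python
import torch
from diffusers import DDPMPipeline
from transformers import ViTForImageClassification, ViTImageProcessor

device = "cuda"
pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256").to(device)
classifier = ViTForImageClassification.from_pretrained("nateraw/vit-age-classifier").to(device).eval()
processor = ViTImageProcessor.from_pretrained("nateraw/vit-age-classifier")

num_batches, batch_size = 4, 8   # 32 samples; increase for a tighter estimate
counts = torch.zeros(classifier.config.num_labels)
for _ in range(num_batches):
    images = pipe(batch_size=batch_size).images            # list of PIL images
    inputs = processor(images, return_tensors="pt").pixel_values.to(device)
    with torch.no_grad():
        preds = classifier(inputs).logits.argmax(dim=-1).cpu()
    counts += torch.bincount(preds, minlength=counts.numel()).float()

# Share of samples per predicted age bucket.
dist = counts / counts.sum()
for idx, share in enumerate(dist.tolist()):
    print(f"{classifier.config.id2label[idx]}: {share:.1%}")
```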
| Experiment | Model (Hugging Face) | W&B |
|---|---|---|
| google/ddpm-celebahq-256 | | |
| Aesthetic Quality | aesthetic-celebahq-256 | run1/run2 |
| Compressibility | compressibility-celebahq-256 | run1/run2 |
| Incompressibility | incompressibility-celebahq-256 | run1/run2 |
| OVER50 | over50-celebahq-256 | run1/run2/run3/run4/run5/run6 |
| google/ddpm-church-256 | | |
| Aesthetic Quality | aesthetic-church-256 | run1/run2 |
| Compressibility | compressibility-church-256 | run1/run2/run3 |
| Incompressibility | incompressibility-church-256 | run1/run2/run3 |
Note: Multiple runs indicate that the experiment continued training from the previous run, using the last saved checkpoint.
To set up the project:
```bash
git clone git@github.com:alcazar90/ddpo-celebahq.git
cd ddpo-celebahq
pip install -e .
```
Example of running the training script:
```bash
python ./ddpo-celebahq/scripts/train.py \
  --wandb_logging \
  --task "aesthetic score" \
  --initial_lr 0.00000009 \
  --peak_lr 0.00000374 \
  --warmup_pct 0.5 \
  --num_samples_per_epoch 100 \
  --batch_size 10 \
  --num_epochs 25 \
  --clip_advantages 10 \
  --num_inner_epochs 1 \
  --eval_every_each_epoch 1 \
  --num_eval_samples 64 \
  --run_seed 92013491249214123 \
  --eval_rnd_seed 650 \
  --save_model \
  --ddpm_ckpt google/ddpm-church-256
```
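For context on flags such as `--clip_advantages` and `--num_inner_epochs`: DDPO optimizes a PPO-style clipped objective over the denoising steps and reuses each batch of sampled trajectories for a few inner epochs. The following is only an illustrative sketch of such a per-step loss (the epsilon value and the advantage normalization are assumptions, not the script's exact code):

```python
import torch


def ddpo_step_loss(
    logprob: torch.Tensor,      # log-prob of the denoising step under the current policy
    old_logprob: torch.Tensor,  # same step's log-prob under the policy that sampled the trajectory
    advantage: torch.Tensor,    # per-sample advantage (e.g., rewards standardized over the batch)
    clip_advantages: float = 10.0,  # mirrors --clip_advantages
    clip_ratio: float = 1e-4,       # illustrative PPO clipping range
) -> torch.Tensor:
    # Clamp extreme advantages to stabilize the update.
    adv = torch.clamp(advantage, -clip_advantages, clip_advantages)
    # PPO-style clipped importance-sampling objective for one denoising step.
    ratio = torch.exp(logprob - old_logprob)
    unclipped = -adv * ratio
    clipped = -adv * torch.clamp(ratio, 1.0 - clip_ratio, 1.0 + clip_ratio)
    return torch.max(unclipped, clipped).mean()
```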
To clone this repo, install the dependencies, and run the training script on a Google Colab instance with a GPU, follow this Colab notebook as an example.