
Finetuned Diffusion Models using Reinforcement Learning 👾

Post-training experiments on diffusion models using policy gradient methods. We explore ideas with pretrained, unconditional DDPM models such as google/ddpm-celebahq-256 and google/ddpm-church-256, which are based on Denoising Diffusion Probabilistic Models (Ho et al., 2020).
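For context, either pretrained base model can be loaded and sampled directly with the diffusers library; a minimal sketch (batch size and step count here are illustrative):

import torch
from diffusers import DDPMPipeline

# Load the pretrained unconditional DDPM (256x256 CelebA-HQ faces).
pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Each call runs the full reverse diffusion process and returns PIL images.
images = pipe(batch_size=4, num_inference_steps=1000).images
images[0].save("ddpm_sample.png")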

Reward Finetuning is a two-step process to align a pretrained diffusion model (DDPM) with downstream tasks using reward-based optimization, known as DDPO (Black et al., 2023). In the first step, the diffusion model generates samples (trajectories), which are evaluated by a reward model acting as an oracle to determine which samples to incentivize. In the second step, gradients are estimated via Monte Carlo methods based on the collected dataset to update the diffusion model parameters.
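The gradient step in the second stage can be pictured with a small, self-contained sketch using dummy tensors (the shapes, names, and advantage normalization are illustrative rather than the repo's exact implementation; the clipping mirrors the --clip_advantages flag of the training script):

import torch

# Illustrative setup: 8 sampled trajectories, 40 denoising steps each.
num_samples, num_steps = 8, 40
log_probs = torch.randn(num_samples, num_steps, requires_grad=True)  # log-prob of each reverse step
rewards = torch.randn(num_samples)                                   # reward-model score per final image

# Turn rewards into advantages so above-average samples are reinforced
# and below-average samples are discouraged.
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
advantages = advantages.clamp(-10, 10)  # cf. --clip_advantages

# REINFORCE-style loss: weight every denoising step of a trajectory by that
# trajectory's advantage and take a gradient step on the diffusion model.
loss = -(advantages[:, None] * log_probs).mean()
loss.backward()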

Downstream Tasks

This project supports the three main downstream tasks used to align image generation in the Training Diffusion Models with Reinforcement Learning (Black et al., 2023) work:

  • Aesthetic Quality (transition animation)
  • JPEG Compressibility (transition animation)
  • JPEG Incompressibility (transition animation)
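For the two JPEG-based tasks, the reward in Black et al. (2023) is simply the size of the image's JPEG encoding: small files score high for compressibility, large files score high for incompressibility. A minimal sketch (the function names and quality setting are illustrative, not the repo's implementation):

import io
from PIL import Image

def jpeg_size_kb(img: Image.Image, quality: int = 95) -> float:
    """Size of the image after JPEG encoding, in kilobytes."""
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG", quality=quality)
    return len(buffer.getvalue()) / 1024.0

def compressibility_reward(img: Image.Image) -> float:
    # Smaller encodings (smoother, low-detail images) get higher reward.
    return -jpeg_size_kb(img)

def incompressibility_reward(img: Image.Image) -> float:
    # Larger encodings (high-frequency, detailed images) get higher reward.
    return jpeg_size_kb(img)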

A visual comparison between DDPM samples (i.e., pretrained) and DDPO samples generated from the same initial noise: the DDPO samples use checkpoints finetuned with their corresponding reward functions.

Visual comparison of DDPM and DDPO samples optimized for various downstream tasks.

OVER50: designing the reward function with an off-the-shelf classifier

  • Roughly 6% of the samples generated by the google/ddpm-celebahq-256 model depict faces aged 50 or older.
  • Can we push RL to generate more samples of this kind?
  • Goal: increase the frequency of generated celebrity-like faces over 50 years old.
  • The ViT Age Classifier (Nate Raw, 2021), trained on the FairFace dataset, is used to predict the age of the samples.
Using a ViT Age Classifier to design the OVER50 reward. The reward is computed as the sum of the logits for the relevant age classes, incentivizing samples with a higher likelihood of depicting faces over 50 years old.
from typing import Any, Callable

import torch


def over50_old(
    device: str = "cuda",
) -> Callable[[Any], torch.Tensor]:
    """Calculate the rewards for images with probabilities over 50 years old."""
    from transformers import ViTForImageClassification, ViTImageProcessor

    model = ViTForImageClassification.from_pretrained("nateraw/vit-age-classifier")
    transforms = ViTImageProcessor.from_pretrained("nateraw/vit-age-classifier")
    model.to(device)
    model.eval()

    def _fn(images):
        # `decode_tensor_to_np_img` is a project helper that converts the batch of
        # sampled tensors back into numpy images the ViT processor can consume.
        inputs = transforms(
            decode_tensor_to_np_img(
                images,
                melt_batch=False,
            ),
            return_tensors="pt",
        ).pixel_values.to(device)
        with torch.no_grad():
            outputs = model(inputs).logits
        # The classifier predicts nine FairFace age bins; indices 6, 7, and 8 are
        # "50-59", "60-69", and "more than 70", so summing their logits rewards
        # samples likely to depict faces over 50.
        return outputs[:, 6:].sum(dim=1)

    return _fn
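The returned closure can then be called on a batch of samples, e.g. (illustrative; `images` stands for a batch of decoded sample tensors produced during the rollout):

reward_fn = over50_old(device="cuda")
rewards = reward_fn(images)  # one scalar reward per image in the batch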
Age distribution of DDPM samples (left) vs. the post-training distribution with OVER50 (right).

Experiment Details 🧪

| Experiment | Model (Hugging Face) | W&B |
| --- | --- | --- |
| google/ddpm-celebahq-256 | | |
| Aesthetic Quality | aesthetic-celebahq-256 | run1/run2 |
| Compressibility | compressibility-celebahq-256 | run1/run2 |
| Incompressibility | incompressibility-celebahq-256 | run1/run2 |
| OVER50 | over50-celebahq-256 | run1/run2/run3/run4/run5/run6 |
| google/ddpm-church-256 | | |
| Aesthetic Quality | aesthetic-church-256 | run1/run2 |
| Compressibility | compressibility-church-256 | run1/run2/run3 |
| Incompressibility | incompressibility-church-256 | run1/run2/run3 |

Note: Multiple runs indicate that the experiment continued training from the previous run, using the last saved checkpoint.

Overoptimization and mode collapse

Comparison of image synthesis using CelebA-HQ-based models. 2D projection of CLIP embeddings for two sets of 1,000 samples: i) DDPM samples (black borders) and ii) DDPO samples finetuned with the LAION aesthetic reward (white borders). The DDPO samples were optimized to achieve a higher average aesthetic score (5.58 vs. 5.11), indicating better aesthetic quality. Notably, the DDPO samples cluster more tightly (red ellipse) around the highest-scoring DDPM sample, suggesting a mode-collapse effect. Both sets of samples were generated using the same seed.
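The README does not spell out how the projection was computed; one way to reproduce a comparable plot is to embed both sample sets with a CLIP image encoder and reduce to 2D, e.g. with PCA (the checkpoint and reduction method below are assumptions, not the repo's implementation):

import torch
from PIL import Image
from sklearn.decomposition import PCA
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def clip_embeddings(images: list[Image.Image]) -> torch.Tensor:
    """Return one CLIP image embedding per input image."""
    inputs = processor(images=images, return_tensors="pt")
    with torch.no_grad():
        return model.get_image_features(**inputs)

# `ddpm_images` / `ddpo_images` would hold the two sets of 1,000 PIL samples:
# embeddings = torch.cat([clip_embeddings(ddpm_images), clip_embeddings(ddpo_images)])
# coords = PCA(n_components=2).fit_transform(embeddings.numpy())  # 2D points to scatter-plot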

Getting Started

Open In Colab

To set up the project:

git clone git@github.com:alcazar90/ddpo-celebahq.git
cd ddpo-celebahq
pip install -e .

Example of running the training script:

python ./ddpo-celebahq/scripts/train.py \
--wandb_logging \
--task "aesthetic score" \
--initial_lr 0.00000009 \
--peak_lr 0.00000374 \
--warmup_pct 0.5 \
--num_samples_per_epoch 100 \
--batch_size 10 \
--num_epochs 25 \
--clip_advantages 10 \
--num_inner_epochs 1 \
--eval_every_each_epoch 1 \
--num_eval_samples 64 \
--run_seed 92013491249214123 \
--eval_rnd_seed 650  \
--save_model \
--ddpm_ckpt google/ddpm-church-256

To clone this repo, install dependencies, and run the training script on a Google Colab instance with a GPU, follow this Colab notebook as an example.
