Commit

use zero2 for train (i love gan)
oahzxl committed Feb 18, 2024
1 parent 7b67fe8 commit 29335b8
Showing 2 changed files with 142 additions and 33 deletions.
6 changes: 4 additions & 2 deletions dit/models/dit.py
@@ -62,8 +62,10 @@ def timestep_embedding(t, dim, max_period=10000):
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding

def forward(self, t):
def forward(self, t, dtype):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
if t_freq.dtype != dtype:
t_freq = t_freq.to(dtype)
t_emb = self.mlp(t_freq)
return t_emb

@@ -240,7 +242,7 @@ def forward(self, x, t, y):
y: (N,) tensor of class labels
"""
x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
t = self.t_embedder(t) # (N, D)
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
c = t + y # (N, D)
for block in self.blocks:
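The new dtype argument matters because timestep_embedding builds its sinusoidal table in float32, while under the fp16/bf16 plugins the embedder's MLP weights are low precision; passing a float32 tensor into a low-precision Linear raises a dtype-mismatch error. A minimal standalone sketch of the failure mode and the fix (the layer sizes below are chosen for illustration, not taken from the model definition):

import torch
import torch.nn as nn

# Stand-in for the TimestepEmbedder MLP running in bf16 under the zero2 plugin.
mlp = nn.Sequential(nn.Linear(256, 1152), nn.SiLU(), nn.Linear(1152, 1152)).to(torch.bfloat16)

t_freq = torch.randn(4, 256)            # sinusoidal embedding comes out in float32
# mlp(t_freq) would fail with a dtype-mismatch error (float32 input, bfloat16 weights)
t_emb = mlp(t_freq.to(torch.bfloat16))  # the cast added in this commit avoids that
print(t_emb.dtype)                      # torch.bfloat16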
169 changes: 138 additions & 31 deletions dit/train.py
@@ -13,31 +13,59 @@
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
import argparse
import json
import logging
import os
from collections import OrderedDict
from copy import deepcopy
from glob import glob
from time import time

# from torch.utils.data import DataLoader
# from torch.utils.data.distributed import DistributedSampler
# from torchvision.datasets import ImageFolder
# from torchvision import transforms
import numpy as np
import torch.distributed as dist
import tqdm
from diffusers.models import AutoencoderKL
from models.diffusion import create_diffusion
from models.dit import DiT_models
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device

#################################################################################
# Training Helper Functions #
#################################################################################


def get_model_numel(model: torch.nn.Module) -> int:
return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f"{numel / B:.2f} B"
elif numel >= M:
return f"{numel / M:.2f} M"
elif numel >= K:
return f"{numel / K:.2f} K"
else:
return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor


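# Assumed usage (sketch, not visible in the expanded hunks): all_reduce_mean is the
# typical way to average the running loss across ranks before logging, so every
# process reports the same value; `running_loss` is assumed to be the counter the
# script tracks alongside log_steps, e.g.
#   avg_loss = all_reduce_mean(torch.tensor(running_loss / log_steps, device=device))
#   if coordinator.is_master():
#       logger.info(f"step={train_steps}: loss={avg_loss.item():.4f}")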
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
"""
@@ -112,43 +140,112 @@ def main(args):
"""
assert torch.cuda.is_available(), "Training currently requires at least one GPU."

# Setup DDP:
dist.init_process_group("nccl")
assert args.global_batch_size % dist.get_world_size() == 0, f"Batch size must be divisible by world size."
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
rank = dist.get_rank()
device = rank % torch.cuda.device_count()
seed = args.global_seed * dist.get_world_size() + rank
torch.manual_seed(seed)
torch.cuda.set_device(device)
print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")

# Setup an experiment folder:
if rank == 0:
device = get_current_device()

# ==============================
# Setup an experiment folder
# ==============================
if coordinator.is_master():
os.makedirs(args.results_dir, exist_ok=True) # Make results folder (holds all experiment subfolders)
experiment_index = len(glob(f"{args.results_dir}/*"))
model_string_name = args.model.replace("/", "-") # e.g., DiT-XL/2 --> DiT-XL-2 (for naming folders)
experiment_dir = f"{args.results_dir}/{experiment_index:03d}-{model_string_name}" # Create an experiment folder
checkpoint_dir = f"{experiment_dir}/checkpoints" # Stores saved model checkpoints
os.makedirs(checkpoint_dir, exist_ok=True)
with open(f"{experiment_dir}/config.txt", "w") as f:
json.dump(args.__dict__, f, indent=4)
logger = create_logger(experiment_dir)
logger.info(f"Experiment directory created at {experiment_dir}")
else:
logger = create_logger(None)

# Create model:
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
tensorboard_dir = f"{experiment_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)  # keep a handle so the writer can be used for logging

# ==============================
# Initialize Booster
# ==============================
if args.plugin == "gemini":
plugin = GeminiPlugin(
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
precision=args.mixed_precision,
placement_policy="auto",
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip,
)
elif args.plugin == "zero2_cpu":
plugin = LowLevelZeroPlugin(
stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip,
)
elif args.plugin == "3d":
plugin = HybridParallelPlugin(
tp_size=args.tp,
pp_size=1,
zero_stage=args.zero,
max_norm=args.grad_clip,
precision=args.mixed_precision,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")
booster = Booster(plugin=plugin)
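# Rough sizing note (assumptions: DiT-XL/2 at ~675M params, Adam-style state of
# ~12 bytes/param for an fp32 master copy plus two moments, 8 GPUs): with plain
# DDP every rank would hold ~8 GB of optimizer state; with stage-2 ZeRO that state
# (and the gradients) is sharded to ~1 GB per rank, while the ~1.3 GB of bf16
# parameters remains replicated. Exact numbers depend on the plugin internals.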

# ======================================================
# Initialize Model, Objective, Optimizer
# ======================================================
# Create model
assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
latent_size = args.image_size // 8
model = DiT_models[args.model](input_size=latent_size, num_classes=args.num_classes)
dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
model = DiT_models[args.model](input_size=latent_size, num_classes=args.num_classes).to(device).to(dtype)
model_numel = get_model_numel(model)
logger.info(f"Model params: {format_numel_str(model_numel)}")

# Create ema and vae model
# Note that parameter initialization is done within the DiT constructor
ema = deepcopy(model).to(device) # Create an EMA of the model for use after training
requires_grad(ema, False)
model = DDP(model.to(device), device_ids=[rank])
diffusion = create_diffusion(timestep_respacing="") # default: 1000 steps, linear noise schedule
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)
logger.info(f"DiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper):
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0)
# Setup optimizer (we used default Adam betas=(0.9, 0.999) and a constant learning rate of 1e-4 in our paper)
optimizer = HybridAdam(model.parameters(), lr=args.lr, weight_decay=0, adamw_mode=True)

# Prepare models for training
update_ema(ema, model, decay=0) # Ensure EMA is initialized with synced weights
model.train() # important! This enables embedding dropout for classifier-free guidance
ema.eval() # EMA model should always be in eval mode

# Boost model for distributed training
torch.set_default_dtype(dtype)
model, optimizer, _, _, _ = booster.boost(model=model, optimizer=optimizer)
torch.set_default_dtype(torch.float)

# Setup data:
# transform = transforms.Compose([
@@ -176,11 +273,6 @@ def main(args):
# )
# logger.info(f"Dataset contains {len(dataset):,} images ({args.data_path})")

# Prepare models for training:
update_ema(ema, model.module, decay=0) # Ensure EMA is initialized with synced weights
model.train() # important! This enables embedding dropout for classifier-free guidance
ema.eval() # EMA model should always be in eval mode

# Variables for monitoring/logging purposes:
train_steps = 0
log_steps = 0
@@ -189,6 +281,8 @@

batch = int(args.global_batch_size // dist.get_world_size())

# TODO: load ckpt

logger.info(f"Training for {args.epochs} epochs...")
for epoch in range(args.epochs):
# sampler.set_epoch(epoch)
Expand All @@ -199,16 +293,23 @@ def main(args):
y = torch.randint(0, 1000, (batch,)).to(device)
# x = x.to(device)
# y = y.to(device)

# VAE encode
with torch.no_grad():
# Map input images to latent space + normalize latents:
x = vae.encode(x).latent_dist.sample().mul_(0.18215)
x = x.to(dtype)

# Diffusion
t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
model_kwargs = dict(y=y)
loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
loss = loss_dict["loss"].mean()
opt.zero_grad()
loss.backward()
opt.step()
booster.backward(loss=loss, optimizer=optimizer)
optimizer.step()
optimizer.zero_grad()

# Update EMA
update_ema(ema, model.module)

# Log loss values:
@@ -238,7 +339,7 @@
checkpoint = {
"model": model.module.state_dict(),
"ema": ema.state_dict(),
"opt": opt.state_dict(),
"opt": optimizer.state_dict(),
"args": args,
}
checkpoint_path = f"{checkpoint_dir}/{train_steps:07d}.pt"
@@ -257,6 +358,9 @@
# Default args here will train DiT-XL/2 with the hyperparameters we used in our paper (except training iters).
parser = argparse.ArgumentParser()
parser.add_argument("--data-path", type=str, required=True)
parser.add_argument(
"--plugin", type=str, default="zero2", choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"]
)
parser.add_argument("--results-dir", type=str, default="results")
parser.add_argument("--model", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
@@ -268,5 +372,8 @@
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--log-every", type=int, default=100)
parser.add_argument("--ckpt-every", type=int, default=50_000)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--lr", type=float, default=1e-4, help="Gradient clipping value")
args = parser.parse_args()
main(args)
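Since the script now initializes distributed state via colossalai.launch_from_torch, it would presumably be launched with torchrun rather than plain python, e.g. torchrun --standalone --nproc_per_node 8 dit/train.py --data-path /path/to/data --plugin zero2 --mixed_precision bf16 (the GPU count and data path are placeholders; the launch command itself is not part of this commit).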
