Various transformer updates to improve performance (#182)
coryMosaicML authored Nov 26, 2024
1 parent cb14024 commit b7e5029
Showing 5 changed files with 522 additions and 221 deletions.
114 changes: 68 additions & 46 deletions diffusion/callbacks/log_diffusion_images.py
@@ -37,6 +37,7 @@ class LogDiffusionImages(Callback):
         seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
             Default: ``1138``.
         use_table (bool): Whether to make a table of the images or not. Default: ``False``.
+        use_mask (bool): Whether or not to use the mask for the encoded text. Default: ``True``.
         t5_encoder (str, optional): path to the T5 encoder to as a second text encoder.
         clip_encoder (str, optional): path to the CLIP encoder as the first text encoder.
         t5_latent_key: (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
@@ -56,6 +57,7 @@ def __init__(self,
                  rescaled_guidance: Optional[float] = None,
                  seed: Optional[int] = 1138,
                  use_table: bool = False,
+                 use_mask: bool = True,
                  t5_encoder: Optional[str] = None,
                  clip_encoder: Optional[str] = None,
                  t5_latent_key: str = 'T5_LATENTS',
@@ -71,6 +73,7 @@ def __init__(self,
         self.rescaled_guidance = rescaled_guidance
         self.seed = seed
         self.use_table = use_table
+        self.use_mask = use_mask
         self.t5_latent_key = t5_latent_key
         self.t5_mask_key = t5_mask_key
         self.clip_latent_key = clip_latent_key
@@ -100,47 +103,47 @@ def __init__(self,
                                                             local_files_only=True)

             t5_model = AutoModel.from_pretrained(t5_encoder,
-                                                 torch_dtype=torch.float16,
+                                                 torch_dtype=torch.bfloat16,
                                                  cache_dir=self.cache_dir,
                                                  local_files_only=True).encoder.cuda().eval()
             clip_model = CLIPTextModel.from_pretrained(clip_encoder,
                                                        subfolder='text_encoder',
-                                                       torch_dtype=torch.float16,
+                                                       torch_dtype=torch.bfloat16,
                                                        cache_dir=self.cache_dir,
                                                        local_files_only=True).cuda().eval()

-            for batch in self.batched_prompts:
-                latent_batch = {}
-                tokenized_t5 = t5_tokenizer(batch,
-                                            padding='max_length',
-                                            max_length=t5_tokenizer.model_max_length,
-                                            truncation=True,
-                                            return_tensors='pt')
-                t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
-                t5_ids = tokenized_t5['input_ids'].cuda()
-                t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
-                t5_attention_mask = t5_attention_mask.cpu().to(torch.long)
-
-                tokenized_clip = clip_tokenizer(batch,
-                                                padding='max_length',
-                                                max_length=clip_tokenizer.model_max_length,
-                                                truncation=True,
-                                                return_tensors='pt')
-                clip_attention_mask = tokenized_clip['attention_mask'].cuda()
-                clip_ids = tokenized_clip['input_ids'].cuda()
-                clip_outputs = clip_model(input_ids=clip_ids,
-                                          attention_mask=clip_attention_mask,
-                                          output_hidden_states=True)
-                clip_latents = clip_outputs.hidden_states[-2].cpu()
-                clip_pooled = clip_outputs[1].cpu()
-                clip_attention_mask = clip_attention_mask.cpu().to(torch.long)
-
-                latent_batch[self.t5_latent_key] = t5_latents
-                latent_batch[self.t5_mask_key] = t5_attention_mask
-                latent_batch[self.clip_latent_key] = clip_latents
-                latent_batch[self.clip_mask_key] = clip_attention_mask
-                latent_batch[self.clip_pooled_key] = clip_pooled
-                self.batched_latents.append(latent_batch)
+            with torch.no_grad():
+                for batch in self.batched_prompts:
+                    latent_batch = {}
+                    tokenized_t5 = t5_tokenizer(batch,
+                                                padding='max_length',
+                                                max_length=t5_tokenizer.model_max_length,
+                                                truncation=True,
+                                                return_tensors='pt')
+                    t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
+                    t5_ids = tokenized_t5['input_ids'].cuda()
+                    t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
+                    t5_attention_mask = t5_attention_mask.cpu().to(torch.long)
+
+                    tokenized_clip = clip_tokenizer(batch,
+                                                    padding='max_length',
+                                                    max_length=clip_tokenizer.model_max_length,
+                                                    truncation=True,
+                                                    return_tensors='pt')
+                    clip_attention_mask = tokenized_clip['attention_mask'].cuda()
+                    clip_ids = tokenized_clip['input_ids'].cuda()
+                    clip_outputs = clip_model(input_ids=clip_ids,
+                                              attention_mask=clip_attention_mask,
+                                              output_hidden_states=True)
+                    clip_latents = clip_outputs.hidden_states[-2].cpu()
+                    clip_pooled = clip_outputs[1].cpu()
+                    clip_attention_mask = clip_attention_mask.cpu().to(torch.long)
+
+                    latent_batch[self.t5_latent_key] = t5_latents
+                    latent_batch[self.t5_mask_key] = t5_attention_mask
+                    latent_batch[self.clip_latent_key] = clip_latents
+                    latent_batch[self.clip_mask_key] = clip_attention_mask
+                    latent_batch[self.clip_pooled_key] = clip_pooled
+                    self.batched_latents.append(latent_batch)

             del t5_model
             del clip_model
@@ -160,21 +163,40 @@ def eval_start(self, state: State, logger: Logger):
         if self.precomputed_latents:
             for batch in self.batched_latents:
                 pooled_prompt = batch[self.clip_pooled_key].cuda()
-                prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
-                                                                           batch[self.clip_latent_key].cuda(),
-                                                                           batch[self.t5_mask_key].cuda(),
-                                                                           batch[self.clip_mask_key].cuda())
-                gen_images = model.generate(prompt_embeds=prompt_embeds,
-                                            pooled_prompt=pooled_prompt,
-                                            prompt_mask=prompt_mask,
-                                            height=self.size[0],
-                                            width=self.size[1],
-                                            guidance_scale=self.guidance_scale,
-                                            rescaled_guidance=self.rescaled_guidance,
-                                            progress_bar=False,
-                                            num_inference_steps=self.num_inference_steps,
-                                            seed=self.seed)
+                if self.use_mask:
+                    prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
+                                                                               batch[self.clip_latent_key].cuda(),
+                                                                               batch[self.t5_mask_key].cuda(),
+                                                                               batch[self.clip_mask_key].cuda())
+                    gen_images = model.generate(prompt_embeds=prompt_embeds,
+                                                pooled_prompt=pooled_prompt,
+                                                prompt_mask=prompt_mask,
+                                                height=self.size[0],
+                                                width=self.size[1],
+                                                guidance_scale=self.guidance_scale,
+                                                rescaled_guidance=self.rescaled_guidance,
+                                                progress_bar=False,
+                                                num_inference_steps=self.num_inference_steps,
+                                                seed=self.seed)
+                else:
+                    prompt_embeds = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
+                                                                  batch[self.clip_latent_key].cuda())
+                    gen_images = model.generate(prompt_embeds=prompt_embeds,
+                                                pooled_prompt=pooled_prompt,
+                                                height=self.size[0],
+                                                width=self.size[1],
+                                                guidance_scale=self.guidance_scale,
+                                                rescaled_guidance=self.rescaled_guidance,
+                                                progress_bar=False,
+                                                num_inference_steps=self.num_inference_steps,
+                                                seed=self.seed)
                 all_gen_images.append(gen_images)
+                # Clear up GPU tensors
+                del pooled_prompt
+                del prompt_embeds
+                if self.use_mask:
+                    del prompt_mask
+                torch.cuda.empty_cache()
         else:
             for batch in self.batched_prompts:
                 gen_images = model.generate(
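The main performance change in this callback is loading the text encoders in bfloat16 and running the one-time prompt-latent precomputation under torch.no_grad(), so no autograd graph is built and encoder activations are freed immediately. A minimal standalone sketch of the same pattern, assuming a CUDA device; the checkpoint name and prompts below are placeholders, not values from the callback:

import torch
from transformers import AutoModel, AutoTokenizer

t5_name = 'google/t5-v1_1-xxl'  # placeholder checkpoint for illustration
prompts = ['a photo of a corgi', 'an oil painting of a lighthouse']

tokenizer = AutoTokenizer.from_pretrained(t5_name)
# bfloat16 halves encoder memory vs. float32; .eval() disables dropout for deterministic latents.
encoder = AutoModel.from_pretrained(t5_name, torch_dtype=torch.bfloat16).encoder.cuda().eval()

with torch.no_grad():  # no gradients are tracked, so intermediate activations are not retained
    tokens = tokenizer(prompts,
                       padding='max_length',
                       max_length=tokenizer.model_max_length,
                       truncation=True,
                       return_tensors='pt')
    mask = tokens['attention_mask'].to(torch.bool).cuda()
    latents = encoder(input_ids=tokens['input_ids'].cuda(), attention_mask=mask)[0].cpu()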
91 changes: 91 additions & 0 deletions diffusion/datasets/synthetic_image_caption_latents.py
@@ -0,0 +1,91 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Synthetic Image-Caption dataset."""

from typing import Dict, Optional

import torch
from composer.utils import dist
from torch.utils.data import DataLoader, Dataset


class SyntheticImageCaptionLatentsDataset(Dataset):
"""Synthetic dataset imitating a dataset containing image-caption pairs.
Args:
image_size (int): Size of the synthetic images. Default: ``512``.
clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
"""

def __init__(self,
image_size: int = 512,
clip_length: int = 77,
clip_dim: int = 768,
t5_length: int = 512,
t5_dim: int = 4096):

super().__init__()
self.image_size = image_size
self.clip_length = clip_length
self.clip_dim = clip_dim
self.t5_length = t5_length
self.t5_dim = t5_dim

def __len__(self):
return 100_000

def __getitem__(self, idx):
out = {}
out['cond_crops_coords_top_left'] = torch.tensor([0, 0], dtype=torch.float)
out['cond_original_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
out['cond_target_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
out['image'] = torch.randn(3, self.image_size, self.image_size)
out['CLIP_LATENTS'] = torch.randn(self.clip_length, self.clip_dim, dtype=torch.float)
out['CLIP_POOLED'] = torch.randn(self.clip_dim, dtype=torch.float)
out['CLIP_ATTENTION_MASK'] = torch.ones(self.clip_length)
out['T5_LATENTS'] = torch.randn(self.t5_length, self.t5_dim, dtype=torch.float)
out['T5_ATTENTION_MASK'] = torch.ones(self.t5_length)
return out


def build_synthetic_image_caption_latents_dataloader(
batch_size: int,
image_size: int = 512,
clip_length: int = 77,
clip_dim: int = 768,
t5_length: int = 512,
t5_dim: int = 4096,
dataloader_kwargs: Optional[Dict] = None,
):
"""Builds a dataloader for the synthetic image-caption dataset.
Args:
batch_size (int): Batch size for the dataloader.
image_size (int): Size of the synthetic images. Default: ``512``.
clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
dataloader_kwargs (optional, dict): Additional arguments to pass to the dataloader. Default ``None``.
"""
if dataloader_kwargs is None:
dataloader_kwargs = {}

dataset = SyntheticImageCaptionLatentsDataset(image_size=image_size,
clip_length=clip_length,
clip_dim=clip_dim,
t5_length=t5_length,
t5_dim=t5_dim)

dataloader = DataLoader(
dataset=dataset,
sampler=dist.get_sampler(dataset),
batch_size=batch_size,
**dataloader_kwargs,
)

return dataloader
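
For reference, a small usage sketch of the new builder; the batch size, image size, and worker count are arbitrary, the module is assumed importable from the path shown above, and a single-process run is assumed (Composer's dist utilities fall back to a world size of one):

from diffusion.datasets.synthetic_image_caption_latents import build_synthetic_image_caption_latents_dataloader

# Arbitrary illustrative settings; the defaults mirror SDXL-style CLIP/T5 latent shapes.
dataloader = build_synthetic_image_caption_latents_dataloader(
    batch_size=4,
    image_size=256,
    dataloader_kwargs={'num_workers': 0},
)

batch = next(iter(dataloader))
print(batch['image'].shape)        # torch.Size([4, 3, 256, 256])
print(batch['T5_LATENTS'].shape)   # torch.Size([4, 512, 4096])
print(batch['CLIP_POOLED'].shape)  # torch.Size([4, 768])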