Various transformer updates to improve performance #182

Merged
merged 8 commits on Nov 26, 2024
Changes from all commits
114 changes: 68 additions & 46 deletions diffusion/callbacks/log_diffusion_images.py
@@ -37,6 +37,7 @@ class LogDiffusionImages(Callback):
seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
Default: ``1138``.
use_table (bool): Whether to make a table of the images or not. Default: ``False``.
use_mask (bool): Whether or not to use the mask for the encoded text. Default: ``True``.
t5_encoder (str, optional): Path to the T5 encoder to use as the second text encoder.
clip_encoder (str, optional): Path to the CLIP encoder to use as the first text encoder.
t5_latent_key (str): Key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
@@ -56,6 +57,7 @@ def __init__(self,
rescaled_guidance: Optional[float] = None,
seed: Optional[int] = 1138,
use_table: bool = False,
use_mask: bool = True,
t5_encoder: Optional[str] = None,
clip_encoder: Optional[str] = None,
t5_latent_key: str = 'T5_LATENTS',
@@ -71,6 +73,7 @@ def __init__(self,
self.rescaled_guidance = rescaled_guidance
self.seed = seed
self.use_table = use_table
self.use_mask = use_mask
self.t5_latent_key = t5_latent_key
self.t5_mask_key = t5_mask_key
self.clip_latent_key = clip_latent_key
@@ -100,47 +103,47 @@ def __init__(self,
local_files_only=True)

t5_model = AutoModel.from_pretrained(t5_encoder,
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
cache_dir=self.cache_dir,
local_files_only=True).encoder.cuda().eval()
clip_model = CLIPTextModel.from_pretrained(clip_encoder,
subfolder='text_encoder',
torch_dtype=torch.float16,
torch_dtype=torch.bfloat16,
cache_dir=self.cache_dir,
local_files_only=True).cuda().eval()

for batch in self.batched_prompts:
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

tokenized_clip = clip_tokenizer(batch,
with torch.no_grad():
for batch in self.batched_prompts:
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch[self.t5_latent_key] = t5_latents
latent_batch[self.t5_mask_key] = t5_attention_mask
latent_batch[self.clip_latent_key] = clip_latents
latent_batch[self.clip_mask_key] = clip_attention_mask
latent_batch[self.clip_pooled_key] = clip_pooled
self.batched_latents.append(latent_batch)
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
t5_ids = tokenized_t5['input_ids'].cuda()
t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
t5_attention_mask = t5_attention_mask.cpu().to(torch.long)

tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
clip_ids = tokenized_clip['input_ids'].cuda()
clip_outputs = clip_model(input_ids=clip_ids,
attention_mask=clip_attention_mask,
output_hidden_states=True)
clip_latents = clip_outputs.hidden_states[-2].cpu()
clip_pooled = clip_outputs[1].cpu()
clip_attention_mask = clip_attention_mask.cpu().to(torch.long)

latent_batch[self.t5_latent_key] = t5_latents
latent_batch[self.t5_mask_key] = t5_attention_mask
latent_batch[self.clip_latent_key] = clip_latents
latent_batch[self.clip_mask_key] = clip_attention_mask
latent_batch[self.clip_pooled_key] = clip_pooled
self.batched_latents.append(latent_batch)

del t5_model
del clip_model
@@ -160,21 +163,40 @@ def eval_start(self, state: State, logger: Logger):
if self.precomputed_latents:
for batch in self.batched_latents:
pooled_prompt = batch[self.clip_pooled_key].cuda()
prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
batch[self.clip_latent_key].cuda(),
batch[self.t5_mask_key].cuda(),
batch[self.clip_mask_key].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
if self.use_mask:
prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
batch[self.clip_latent_key].cuda(),
batch[self.t5_mask_key].cuda(),
batch[self.clip_mask_key].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
prompt_mask=prompt_mask,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
else:
prompt_embeds = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
batch[self.clip_latent_key].cuda())
gen_images = model.generate(prompt_embeds=prompt_embeds,
pooled_prompt=pooled_prompt,
height=self.size[0],
width=self.size[1],
guidance_scale=self.guidance_scale,
rescaled_guidance=self.rescaled_guidance,
progress_bar=False,
num_inference_steps=self.num_inference_steps,
seed=self.seed)
all_gen_images.append(gen_images)
# Clear up GPU tensors
del pooled_prompt
del prompt_embeds
if self.use_mask:
del prompt_mask
torch.cuda.empty_cache()
else:
for batch in self.batched_prompts:
gen_images = model.generate(
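
For context on the two code paths above: with `use_mask=True` the callback expects `prepare_text_embeddings` to return both the combined prompt embeddings and a combined mask, while with `use_mask=False` it expects the embeddings alone and omits `prompt_mask` from the `generate` call. Below is a minimal sketch of a function matching that calling convention; it is an illustration only, not the model's actual implementation, and the pad-and-concatenate details are assumptions.

import torch

# Sketch only: combine T5 and CLIP sequence embeddings the way the callback calls it.
# Returns (prompt_embeds, prompt_mask) when both masks are given, else prompt_embeds alone.
# Padding both sequences to a common hidden width before concatenation is an assumption.
def prepare_text_embeddings_sketch(t5_embed, clip_embed, t5_mask=None, clip_mask=None):
    dim = max(t5_embed.shape[-1], clip_embed.shape[-1])
    t5_embed = torch.nn.functional.pad(t5_embed, (0, dim - t5_embed.shape[-1]))
    clip_embed = torch.nn.functional.pad(clip_embed, (0, dim - clip_embed.shape[-1]))
    prompt_embeds = torch.cat([t5_embed, clip_embed], dim=1)  # concatenate along sequence length
    if t5_mask is None or clip_mask is None:
        return prompt_embeds
    prompt_mask = torch.cat([t5_mask, clip_mask], dim=1)
    return prompt_embeds, prompt_mask
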
91 changes: 91 additions & 0 deletions diffusion/datasets/synthetic_image_caption_latents.py
@@ -0,0 +1,91 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Synthetic Image-Caption dataset."""

from typing import Dict, Optional

import torch
from composer.utils import dist
from torch.utils.data import DataLoader, Dataset


class SyntheticImageCaptionLatentsDataset(Dataset):
"""Synthetic dataset imitating a dataset containing image-caption pairs.

Args:
image_size (int): Size of the synthetic images. Default: ``512``.
clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
"""

def __init__(self,
image_size: int = 512,
clip_length: int = 77,
clip_dim: int = 768,
t5_length: int = 512,
t5_dim: int = 4096):

super().__init__()
self.image_size = image_size
self.clip_length = clip_length
self.clip_dim = clip_dim
self.t5_length = t5_length
self.t5_dim = t5_dim

def __len__(self):
return 100_000

def __getitem__(self, idx):
out = {}
out['cond_crops_coords_top_left'] = torch.tensor([0, 0], dtype=torch.float)
out['cond_original_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
out['cond_target_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
out['image'] = torch.randn(3, self.image_size, self.image_size)
out['CLIP_LATENTS'] = torch.randn(self.clip_length, self.clip_dim, dtype=torch.float)
out['CLIP_POOLED'] = torch.randn(self.clip_dim, dtype=torch.float)
out['CLIP_ATTENTION_MASK'] = torch.ones(self.clip_length)
out['T5_LATENTS'] = torch.randn(self.t5_length, self.t5_dim, dtype=torch.float)
out['T5_ATTENTION_MASK'] = torch.ones(self.t5_length)
return out


def build_synthetic_image_caption_latents_dataloader(
batch_size: int,
image_size: int = 512,
clip_length: int = 77,
clip_dim: int = 768,
t5_length: int = 512,
t5_dim: int = 4096,
dataloader_kwargs: Optional[Dict] = None,
):
"""Builds a dataloader for the synthetic image-caption dataset.

Args:
batch_size (int): Batch size for the dataloader.
image_size (int): Size of the synthetic images. Default: ``512``.
clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
dataloader_kwargs (Dict, optional): Additional arguments to pass to the dataloader. Default: ``None``.
"""
if dataloader_kwargs is None:
dataloader_kwargs = {}

dataset = SyntheticImageCaptionLatentsDataset(image_size=image_size,
clip_length=clip_length,
clip_dim=clip_dim,
t5_length=t5_length,
t5_dim=t5_dim)

dataloader = DataLoader(
dataset=dataset,
sampler=dist.get_sampler(dataset),
batch_size=batch_size,
**dataloader_kwargs,
)

return dataloader
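
As a usage sketch, the new builder can stand in for a real dataloader when benchmarking throughput. This assumes a single-process run, where Composer's `dist` utilities fall back to a world size of 1; the batch size and `num_workers` value are example choices, and the printed shapes follow from the dataset defaults above.

from diffusion.datasets.synthetic_image_caption_latents import build_synthetic_image_caption_latents_dataloader

# Build a synthetic loader with the default embedding shapes defined above.
dataloader = build_synthetic_image_caption_latents_dataloader(
    batch_size=4,
    image_size=512,
    dataloader_kwargs={'num_workers': 0},
)

batch = next(iter(dataloader))
print(batch['image'].shape)         # torch.Size([4, 3, 512, 512])
print(batch['T5_LATENTS'].shape)    # torch.Size([4, 512, 4096])
print(batch['CLIP_LATENTS'].shape)  # torch.Size([4, 77, 768])
print(batch['CLIP_POOLED'].shape)   # torch.Size([4, 768])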