fix a few autocast warnings, add new technique for cfg
lucidrains committed Oct 7, 2024
1 parent c166739 commit 192f8b9
Showing 5 changed files with 102 additions and 9 deletions.
README.md: 9 additions & 0 deletions
@@ -947,3 +947,12 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
note = {under review}
}
```

```bibtex
@inproceedings{Sadat2024EliminatingOA,
title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273098845}
}
```
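
As a rough illustration of how the new option surfaces at sampling time, a minimal sketch (assuming an `imagen` instance, an `Imagen` or `ElucidatedImagen`, is already constructed as in the README examples; only `texts`, `cond_scale`, `cfg_remove_parallel_component`, and `cfg_keep_parallel_frac` come from the updated `sample` signature in this commit):

```python
# hypothetical usage sketch: the two new flags ride along with cond_scale and are
# forwarded by `sample` down to each unet's `forward_with_cond_scale`
images = imagen.sample(
    texts = ['a photo of a salamander basking on a mossy rock'],
    cond_scale = 5.,                        # classifier-free guidance scale
    cfg_remove_parallel_component = True,   # project out the part of the guidance update parallel to the conditioned prediction
    cfg_keep_parallel_frac = 0.             # 0. drops the parallel part entirely; 1. falls back to vanilla CFG
)
```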
imagen_pytorch/elucidated_imagen.py: 7 additions & 3 deletions
@@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F
from torch import nn
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.nn.parallel import DistributedDataParallel
import torchvision.transforms as T

@@ -565,6 +565,8 @@ def sample(
video_frames = None,
batch_size = 1,
cond_scale = 1.,
cfg_remove_parallel_component = True,
cfg_keep_parallel_frac = 0.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
@@ -583,7 +585,7 @@ def sample(
if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'

with autocast(enabled = False):
with autocast('cuda', enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -724,6 +726,8 @@ def sample(
sigma_min = unet_sigma_min,
sigma_max = unet_sigma_max,
cond_scale = unet_cond_scale,
remove_parallel_component = cfg_remove_parallel_component,
keep_parallel_frac = cfg_keep_parallel_frac,
lowres_cond_img = lowres_cond_img,
lowres_noise_times = lowres_noise_times,
dynamic_threshold = dynamic_threshold,
@@ -811,7 +815,7 @@ def forward(
assert all([*map(len, texts)]), 'text cannot be empty'
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

with autocast(enabled = False):
with autocast('cuda', enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
imagen_pytorch/imagen_pytorch.py: 53 additions & 4 deletions
@@ -11,7 +11,7 @@
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel
from torch import nn, einsum
from torch.cuda.amp import autocast
from torch.amp import autocast
from torch.special import expm1
import torchvision.transforms as T

@@ -187,6 +187,15 @@ def safe_get_tuple_index(tup, index, default = None):
return default
return tup[index]

def pack_one_with_inverse(x, pattern):
packed, packed_shape = pack([x], pattern)

def inverse(x, inverse_pattern = None):
inverse_pattern = default(inverse_pattern, pattern)
return unpack(x, packed_shape, inverse_pattern)[0]

return packed, inverse
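
The helper above just flattens everything except the batch dimension and returns a closure that undoes the flattening; a standalone sketch of the behaviour (not part of the commit, and it only assumes the `pack` / `unpack` functions from `einops` that the helper itself relies on):

```python
# illustrative only: what pack / unpack with the 'b *' pattern do
import torch
from einops import pack, unpack

x = torch.randn(2, 3, 8, 8)

flat, packed_shape = pack([x], 'b *')            # flatten all non-batch dims -> shape (2, 192)
restored = unpack(flat, packed_shape, 'b *')[0]  # undo the flattening -> shape (2, 3, 8, 8)

assert flat.shape == (2, 192)
assert restored.shape == x.shape
```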

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

@@ -206,6 +215,21 @@ def prob_mask_like(shape, prob, device):
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

# for improved cfg, getting parallel and orthogonal components of cfg update

def project(x, y):
x, inverse = pack_one_with_inverse(x, 'b *')
y, _ = pack_one_with_inverse(y, 'b *')

dtype = x.dtype
x, y = x.double(), y.double()
unit = F.normalize(y, dim = -1)

parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel

return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
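
For intuition, `project` splits `x` into the components parallel and orthogonal to `y`, per batch element, after flattening everything but the batch dimension. A self-contained sanity check of that decomposition written against plain torch (illustrative only, not the repo helpers):

```python
# sketch: parallel + orthogonal reconstructs x, and the orthogonal part has no component along y
import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 16, 16).flatten(1).double()   # e.g. the CFG update, flattened per batch element
y = torch.randn(4, 3, 16, 16).flatten(1).double()   # e.g. the conditioned prediction

unit = F.normalize(y, dim = -1)
parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel

assert torch.allclose(parallel + orthogonal, x)
assert torch.allclose((orthogonal * unit).sum(dim = -1), torch.zeros(4, dtype = torch.double), atol = 1e-6)
```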

# gaussian diffusion with continuous time helper functions and classes
# large part of this was thanks to @crowsonkb at https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/utils.py

@@ -1511,6 +1535,8 @@ def forward_with_cond_scale(
self,
*args,
cond_scale = 1.,
remove_parallel_component = True,
keep_parallel_frac = 0.,
**kwargs
):
logits = self.forward(*args, **kwargs)
@@ -1519,7 +1545,14 @@ def forward_with_cond_scale(
return logits

null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

update = (logits - null_logits)

if remove_parallel_component:
parallel, orthogonal = project(update, logits)
update = orthogonal + parallel * keep_parallel_frac

return logits + update * (cond_scale - 1)
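
Spelling out the algebra of this change (a sketch following the Sadat et al. reference above, with u the guidance update, s the cond scale, and alpha = `keep_parallel_frac`):

```latex
% vanilla CFG, rearranged so the guidance appears as an update on the conditioned output
\hat{x} = x_\varnothing + s\,(x_c - x_\varnothing) = x_c + (s - 1)\,\underbrace{(x_c - x_\varnothing)}_{u}

% the new path projects u onto the conditioned output x_c and keeps only a fraction
% alpha (keep_parallel_frac) of the parallel component before rescaling
u_\parallel = \frac{\langle u, x_c \rangle}{\lVert x_c \rVert^2}\, x_c, \qquad
u_\perp = u - u_\parallel, \qquad
\hat{x} = x_c + (s - 1)\,\bigl(u_\perp + \alpha\, u_\parallel\bigr)
```

Setting `keep_parallel_frac = 1.` (or `remove_parallel_component = False`) reduces this to the old return statement, and at `cond_scale = 1.` the early return above makes both formulations a no-op.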

def forward(
self,
@@ -2055,6 +2088,8 @@ def p_mean_variance(
self_cond = None,
lowres_noise_times = None,
cond_scale = 1.,
cfg_remove_parallel_component = True,
cfg_keep_parallel_frac = 0.,
model_output = None,
t_next = None,
pred_objective = 'noise',
@@ -2076,6 +2111,8 @@ def p_mean_variance(
text_mask = text_mask,
cond_images = cond_images,
cond_scale = cond_scale,
remove_parallel_component = cfg_remove_parallel_component,
keep_parallel_frac = cfg_keep_parallel_frac,
lowres_cond_img = lowres_cond_img,
self_cond = self_cond,
lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_noise_times),
@@ -2124,6 +2161,8 @@ def p_sample(
cond_video_frames = None,
post_cond_video_frames = None,
cond_scale = 1.,
cfg_remove_parallel_component = True,
cfg_keep_parallel_frac = 0.,
self_cond = None,
lowres_cond_img = None,
lowres_noise_times = None,
@@ -2149,6 +2188,8 @@ def p_sample(
text_mask = text_mask,
cond_images = cond_images,
cond_scale = cond_scale,
cfg_remove_parallel_component = cfg_remove_parallel_component,
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
lowres_cond_img = lowres_cond_img,
self_cond = self_cond,
lowres_noise_times = lowres_noise_times,
@@ -2185,6 +2226,8 @@ def p_sample_loop(
init_images = None,
skip_steps = None,
cond_scale = 1,
cfg_remove_parallel_component = False,
cfg_keep_parallel_frac = 0.,
pred_objective = 'noise',
dynamic_threshold = True,
use_tqdm = True
@@ -2260,6 +2303,8 @@ def p_sample_loop(
text_mask = text_mask,
cond_images = cond_images,
cond_scale = cond_scale,
cfg_remove_parallel_component = cfg_remove_parallel_component,
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
self_cond = self_cond,
lowres_cond_img = lowres_cond_img,
lowres_noise_times = lowres_noise_times,
@@ -2308,6 +2353,8 @@ def sample(
skip_steps = None,
batch_size = 1,
cond_scale = 1.,
cfg_remove_parallel_component = True,
cfg_keep_parallel_frac = 0.,
lowres_sample_noise_level = None,
start_at_unet_number = 1,
start_image_or_video = None,
@@ -2326,7 +2373,7 @@ def sample(
if exists(texts) and not exists(text_embeds) and not self.unconditional:
assert all([*map(len, texts)]), 'text cannot be empty'

with autocast(enabled = False):
with autocast('cuda', enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
@@ -2469,6 +2516,8 @@ def sample(
init_images = unet_init_images,
skip_steps = unet_skip_steps,
cond_scale = unet_cond_scale,
cfg_remove_parallel_component = cfg_remove_parallel_component,
cfg_keep_parallel_frac = cfg_keep_parallel_frac,
lowres_cond_img = lowres_cond_img,
lowres_noise_times = lowres_noise_times,
noise_scheduler = noise_scheduler,
@@ -2695,7 +2744,7 @@ def forward(
assert all([*map(len, texts)]), 'text cannot be empty'
assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'

with autocast(enabled = False):
with autocast('cuda', enabled = False):
text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)

text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
imagen_pytorch/imagen_video.py: 32 additions & 1 deletion
@@ -95,6 +95,15 @@ def pad_tuple_to_length(t, length, fillvalue = None):
return t
return (*t, *((fillvalue,) * remain_length))

def pack_one_with_inverse(x, pattern):
packed, packed_shape = pack([x], pattern)

def inverse(x, inverse_pattern = None):
inverse_pattern = default(inverse_pattern, pattern)
return unpack(x, packed_shape, inverse_pattern)[0]

return packed, inverse

# helper classes

class Identity(nn.Module):
@@ -131,6 +140,19 @@ def masked_mean(t, *, dim, mask = None):

return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)

def project(x, y):
x, inverse = pack_one_with_inverse(x, 'b *')
y, _ = pack_one_with_inverse(y, 'b *')

dtype = x.dtype
x, y = x.double(), y.double()
unit = F.normalize(y, dim = -1)

parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
orthogonal = x - parallel

return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)

def resize_video_to(
video,
target_image_size,
@@ -1637,6 +1659,8 @@ def forward_with_cond_scale(
self,
*args,
cond_scale = 1.,
remove_parallel_component = False,
keep_parallel_frac = 0.,
**kwargs
):
logits = self.forward(*args, **kwargs)
@@ -1645,7 +1669,14 @@ def forward_with_cond_scale(
return logits

null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale

update = (logits - null_logits)

if remove_parallel_component:
parallel, orthogonal = project(update, logits)
update = orthogonal + parallel * keep_parallel_frac

return logits + update * (cond_scale - 1)

def forward(
self,
imagen_pytorch/version.py: 1 addition & 1 deletion
@@ -1 +1 @@
__version__ = '2.0.0'
__version__ = '2.1.0'
