cleanup some bad code and allow for customizing more attention layers for the middle of the unet
lucidrains committed Nov 15, 2022
1 parent 103d65f commit a81f7f6
Showing 3 changed files with 21 additions and 15 deletions.
32 changes: 19 additions & 13 deletions imagen_pytorch/imagen_pytorch.py
@@ -18,10 +18,9 @@

 import kornia.augmentation as K

-from einops import rearrange, repeat, reduce
+from einops import rearrange, repeat, reduce, pack, unpack
 from einops.layers.torch import Rearrange, Reduce
 from einops_exts import rearrange_many, repeat_many, check_shape
-from einops_exts.torch import EinopsToAndFrom

 from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

@@ -682,14 +681,10 @@ def __init__(
         if exists(cond_dim):
             attn_klass = CrossAttention if not linear_attn else LinearCrossAttention

-            self.cross_attn = EinopsToAndFrom(
-                'b c h w',
-                'b (h w) c',
-                attn_klass(
-                    dim = dim_out,
-                    context_dim = cond_dim,
-                    **attn_kwargs
-                )
+            self.cross_attn = attn_klass(
+                dim = dim_out,
+                context_dim = cond_dim,
+                **attn_kwargs
             )

         self.block1 = Block(dim, dim_out, groups = groups)
@@ -712,7 +707,11 @@ def forward(self, x, time_emb = None, cond = None):

         if exists(self.cross_attn):
             assert exists(cond)
+            h = rearrange(h, 'b c h w -> b h w c')
+            h, ps = pack([h], 'b * c')
             h = self.cross_attn(h, context = cond) + h
+            h, = unpack(h, ps, 'b * c')
+            h = rearrange(h, 'b h w c -> b c h w')

         h = self.block2(h, scale_shift = scale_shift)

@@ -970,14 +969,20 @@ def __init__(

         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                EinopsToAndFrom('b c h w', 'b (h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
-                ChanFeedForward(dim = dim, mult = ff_mult)
+                Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn),
+                FeedForward(dim = dim, mult = ff_mult)
             ]))

     def forward(self, x, context = None):
+        x = rearrange(x, 'b c h w -> b h w c')
+        x, ps = pack([x], 'b * c')
+
         for attn, ff in self.layers:
             x = attn(x, context = context) + x
             x = ff(x) + x
+
+        x, = unpack(x, ps, 'b * c')
+        x = rearrange(x, 'b h w c -> b c h w')
         return x

 class LinearAttentionTransformerBlock(nn.Module):
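
For reference, a minimal standalone sketch (not part of the commit) of the rearrange + pack/unpack round trip that the rewritten forward passes use in place of EinopsToAndFrom: the spatial grid is flattened into a single token axis for attention, then restored. Shapes are illustrative; pack/unpack need einops>=0.6.

import torch
from einops import rearrange, pack, unpack

x = torch.randn(2, 64, 16, 16)           # (batch, channels, height, width)

x = rearrange(x, 'b c h w -> b h w c')   # move channels last
x, ps = pack([x], 'b * c')               # flatten h and w into one token axis -> (2, 256, 64); ps remembers (16, 16)

# ... attention over the 256 spatial tokens would run here ...

x, = unpack(x, ps, 'b * c')              # restore (2, 16, 16, 64)
x = rearrange(x, 'b h w c -> b c h w')   # back to channels first

assert x.shape == (2, 64, 16, 16)
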
Expand Down Expand Up @@ -1091,6 +1096,7 @@ def __init__(
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
layer_attns = True,
layer_attns_depth = 1,
layer_mid_attns_depth = 1,
layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
layer_cross_attns = True,
@@ -1339,7 +1345,7 @@ def __init__(
         mid_dim = dims[-1]

         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
-        self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
+        self.mid_attn = TransformerBlock(mid_dim, depth = layer_mid_attns_depth, **attn_kwargs) if attend_at_middle else None
         self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])

         # upsample klass
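
The net effect of the last two hunks: the bottleneck attention is now a full TransformerBlock whose depth is configurable. A rough usage sketch, assuming the usual Unet keyword arguments from the readme; all values are illustrative, and only attend_at_middle and layer_mid_attns_depth relate to this commit.

from imagen_pytorch import Unet

unet = Unet(
    dim = 128,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    attend_at_middle = True,       # keep attention at the bottleneck
    layer_mid_attns_depth = 2      # new: stack two attention + feedforward layers at the middle instead of one
)
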
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.15.1'
+__version__ = '1.16.0'
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
     'accelerate',
     'click',
     'datasets',
-    'einops>=0.4',
+    'einops>=0.6',
     'einops-exts',
     'ema-pytorch>=0.0.3',
     'fsspec',
