Commit 8d7032d
readd ability to condition fine transformer with text, more efficient now with cross attention key/value caching

lucidrains committed Jun 1, 2024
1 parent cfe23c6 commit 8d7032d
Showing 3 changed files with 35 additions and 4 deletions.
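The new `fine_cross_attend_text` flag is opt-in and, as the diff below shows, only takes effect when `condition_on_text` is also enabled. A minimal usage sketch (the autoencoder and dimension arguments are illustrative assumptions taken from the project README, not part of this commit):

```python
# hypothetical usage sketch, assuming the usual meshgpt_pytorch constructors
from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

autoencoder = MeshAutoencoder(num_discrete_coors = 128)

transformer = MeshTransformer(
    autoencoder,
    dim = 512,
    condition_on_text = True,       # text conditioning must be on ...
    fine_cross_attend_text = True   # ... for the fine stage to cross attend to text
)
```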
meshgpt_pytorch/meshgpt_pytorch.py (33 additions, 2 deletions)

@@ -1039,6 +1039,7 @@ def __init__(
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
+        fine_cross_attend_text = False,
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1137,6 +1138,8 @@ def __init__(

         # decoding the vertices, 2-stage hierarchy
 
+        self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
+
         self.fine_decoder = Decoder(
             dim = dim_fine,
             depth = fine_attn_depth,
@@ -1145,6 +1148,9 @@
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            cross_attend = self.fine_cross_attend_text,
+            cross_attn_dim_context = cross_attn_dim_context,
+            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
             **attn_kwargs
         )
 
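The three new keyword arguments switch on x-transformers' built-in cross attention for the fine decoder. A standalone sketch of an equivalently configured decoder, with placeholder dimensions (the text embedding width of 768 is an assumption):

```python
import torch
from x_transformers import Decoder

# a minimal sketch mirroring the configuration above
decoder = Decoder(
    dim = 512,                      # dim_fine
    depth = 2,                      # fine_attn_depth
    heads = 8,                      # fine_attn_heads
    attn_dim_head = 32,             # fine_attn_dim_head
    cross_attend = True,            # adds a cross attention block per layer
    cross_attn_dim_context = 768,   # width of the text embeddings
    cross_attn_num_mem_kv = 4       # learned memory key / values for cross attention
)

tokens  = torch.randn(1, 6, 512)    # stand-in for fine vertex codes
context = torch.randn(1, 16, 768)   # stand-in for text embeddings

out = decoder(tokens, context = context)   # (1, 6, 512)
```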
@@ -1512,8 +1518,17 @@ def forward_on_codes(
         if exists(fine_cache):
             for attn_intermediate in fine_cache.attn_intermediates:
                 ck, cv = attn_intermediate.cached_kv
-                ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
-                ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
+                ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
+
+                # when operating on the cached key / values, treat self attention and cross attention differently
+
+                layer_type = attn_intermediate.layer_type
+
+                if layer_type == 'a':
+                    ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
+                elif layer_type == 'c':
+                    ck, cv = [t[:, -1, ...] for t in (ck, cv)]
+
                 attn_intermediate.cached_kv = (ck, cv)
 
         num_faces = fine_vertex_codes.shape[1]
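The branch on `layer_type` is where the efficiency gain in the commit title comes from: self-attention ('a') caches grow with the decoded vertex tokens and must be truncated to the current position, whereas cross-attention ('c') caches hold key / values projected once from the fixed text context and can be reused whole at every decoding step. A shape-level sketch of that distinction, with illustrative tensor sizes:

```python
import torch
from einops import rearrange

batch, num_faces, heads, dim_head = 2, 4, 8, 32
curr_vertex_pos, text_len = 3, 16

# cached keys for one layer, batch flattened over faces as in the fine stage
self_ck  = torch.randn(batch * num_faces, heads, 9, dim_head)         # grows per decoded token
cross_ck = torch.randn(batch * num_faces, heads, text_len, dim_head)  # fixed text context

# unflatten the face dimension, as the diff above does
self_ck, cross_ck = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (self_ck, cross_ck)]

# self attention ('a'): keep only the last face, up to the current vertex position
self_ck = self_ck[:, -1, :, :curr_vertex_pos]   # (batch, heads, curr_vertex_pos, dim_head)

# cross attention ('c'): the text key / values never change, so keep them whole
cross_ck = cross_ck[:, -1, ...]                 # (batch, heads, text_len, dim_head)
```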
@@ -1524,9 +1539,25 @@
         if one_face:
             fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
 
+        # handle maybe cross attention conditioning of fine transformer with text
+
+        fine_attn_context_kwargs = dict()
+
+        if self.fine_cross_attend_text:
+            repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
+
+            text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
+            text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
+
+            fine_attn_context_kwargs = dict(
+                context = text_embed,
+                context_mask = text_mask
+            )
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
+            **fine_attn_context_kwargs,
             return_hiddens = True
         )
 
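Because the fine decoder treats each face as its own batch row, the per-mesh text embeddings must be repeated once per face before they can serve as cross attention context. A standalone sketch of that broadcast, with assumed shapes:

```python
import torch
from einops import repeat

batch, num_faces, text_len, dim_context = 2, 4, 16, 768

fine_vertex_codes = torch.randn(batch * num_faces, 6, 512)   # effective batch is (b nf)
text_embed = torch.randn(batch, text_len, dim_context)       # one embedding per mesh
text_mask  = torch.ones(batch, text_len, dtype = torch.bool)

repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]   # equals num_faces here

text_embed = repeat(text_embed, 'b ... -> (b r) ...', r = repeat_batch)
text_mask  = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)

assert text_embed.shape[0] == fine_vertex_codes.shape[0]   # rows now align face-for-face
```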
meshgpt_pytorch/version.py (1 addition, 1 deletion)

@@ -1 +1 @@
-__version__ = '1.2.9'
+__version__ = '1.2.10'
setup.py (1 addition, 1 deletion)

@@ -36,7 +36,7 @@
     'torchtyping',
     'tqdm',
     'vector-quantize-pytorch>=1.14.22',
-    'x-transformers>=1.30.4',
+    'x-transformers>=1.30.6',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',