From 8d7032d38fa38a3e227c681f8a70968dbbf393fc Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Fri, 31 May 2024 17:00:55 -0700
Subject: [PATCH] readd ability to condition fine transformer with text, more
 efficient now with cross attention key/value caching

---
 meshgpt_pytorch/meshgpt_pytorch.py | 35 ++++++++++++++++++++++++++++--
 meshgpt_pytorch/version.py         |  2 +-
 setup.py                           |  2 +-
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py
index 5c413c50..d8e5ca8c 100644
--- a/meshgpt_pytorch/meshgpt_pytorch.py
+++ b/meshgpt_pytorch/meshgpt_pytorch.py
@@ -1039,6 +1039,7 @@ def __init__(
         fine_attn_depth = 2,
         fine_attn_dim_head = 32,
         fine_attn_heads = 8,
+        fine_cross_attend_text = False,
         pad_id = -1,
         num_sos_tokens = None,
         condition_on_text = False,
@@ -1137,6 +1138,8 @@ def __init__(
 
         # decoding the vertices, 2-stage hierarchy
 
+        self.fine_cross_attend_text = condition_on_text and fine_cross_attend_text
+
         self.fine_decoder = Decoder(
             dim = dim_fine,
             depth = fine_attn_depth,
@@ -1145,6 +1148,9 @@ def __init__(
             attn_flash = flash_attn,
             attn_dropout = dropout,
             ff_dropout = dropout,
+            cross_attend = self.fine_cross_attend_text,
+            cross_attn_dim_context = cross_attn_dim_context,
+            cross_attn_num_mem_kv = cross_attn_num_mem_kv,
             **attn_kwargs
         )
 
@@ -1512,8 +1518,17 @@ def forward_on_codes(
         if exists(fine_cache):
             for attn_intermediate in fine_cache.attn_intermediates:
                 ck, cv = attn_intermediate.cached_kv
-                ck, cv = map(lambda t: rearrange(t, '(b nf) ... -> b nf ...', b = batch), (ck, cv))
-                ck, cv = map(lambda t: t[:, -1, :, :curr_vertex_pos], (ck, cv))
+                ck, cv = [rearrange(t, '(b nf) ... -> b nf ...', b = batch) for t in (ck, cv)]
+
+                # when operating on the cached key / values, treat self attention and cross attention differently
+
+                layer_type = attn_intermediate.layer_type
+
+                if layer_type == 'a':
+                    ck, cv = [t[:, -1, :, :curr_vertex_pos] for t in (ck, cv)]
+                elif layer_type == 'c':
+                    ck, cv = [t[:, -1, ...] for t in (ck, cv)]
+
                 attn_intermediate.cached_kv = (ck, cv)
 
         num_faces = fine_vertex_codes.shape[1]
@@ -1524,9 +1539,25 @@ def forward_on_codes(
         if one_face:
             fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
 
+        # handle maybe cross attention conditioning of fine transformer with text
+
+        fine_attn_context_kwargs = dict()
+
+        if self.fine_cross_attend_text:
+            repeat_batch = fine_vertex_codes.shape[0] // text_embed.shape[0]
+
+            text_embed = repeat(text_embed, 'b ... -> (b r) ...' , r = repeat_batch)
+            text_mask = repeat(text_mask, 'b ... -> (b r) ...', r = repeat_batch)
+
+            fine_attn_context_kwargs = dict(
+                context = text_embed,
+                context_mask = text_mask
+            )
+
         attended_vertex_codes, fine_cache = self.fine_decoder(
             fine_vertex_codes,
             cache = fine_cache,
+            **fine_attn_context_kwargs,
             return_hiddens = True
         )
 
diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py
index 5089ceda..c7ac77f9 100644
--- a/meshgpt_pytorch/version.py
+++ b/meshgpt_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.2.9'
+__version__ = '1.2.10'
diff --git a/setup.py b/setup.py
index 8579bd08..a342256f 100644
--- a/setup.py
+++ b/setup.py
@@ -36,7 +36,7 @@
     'torchtyping',
     'tqdm',
     'vector-quantize-pytorch>=1.14.22',
-    'x-transformers>=1.30.4',
+    'x-transformers>=1.30.6',
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
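
A minimal usage sketch of the new flag (not part of the patch), assuming the
MeshAutoencoder / MeshTransformer API shown in the project README; only
fine_cross_attend_text comes from this change, everything else is illustrative:

    import torch
    from meshgpt_pytorch import MeshAutoencoder, MeshTransformer

    autoencoder = MeshAutoencoder(num_discrete_coors = 128)

    transformer = MeshTransformer(
        autoencoder,
        dim = 512,
        max_seq_len = 768,
        condition_on_text = True,        # text conditioning must be on for the new flag to take effect
        fine_cross_attend_text = True    # new: fine (per-face) decoder also cross attends to the text embeds
    )

    # dummy mesh batch, shapes as in the README example
    vertices = torch.randn((2, 121, 3))          # (batch, num vertices, 3)
    faces = torch.randint(0, 121, (2, 64, 3))    # (batch, num faces, 3)

    loss = transformer(
        vertices = vertices,
        faces = faces,
        texts = ['a high chair', 'a small teapot']
    )

    loss.backward()

During cached generation, the fine decoder's cross attention key / values over the
text embeddings are now cached alongside the self attention key / values, which is
where the efficiency gain in the commit subject comes from.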