From 9db2e343c0c6a515b5cf73390826d8297a7cff96 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 20 Jun 2024 07:33:38 -0700 Subject: [PATCH 1/2] add ability for another gateloop block of any depth after the coarse transformer --- meshgpt_pytorch/meshgpt_pytorch.py | 11 ++++++++++- meshgpt_pytorch/version.py | 2 +- tests/test_meshgpt.py | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index 2485160..4902756 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -1098,6 +1098,7 @@ def __init__( cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition dropout = 0., coarse_pre_gateloop_depth = 2, + coarse_post_gateloop_depth = 0, coarse_adaptive_rmsnorm = False, fine_pre_gateloop_depth = 2, gateloop_use_heinsen = False, @@ -1179,6 +1180,8 @@ def __init__( self.coarse_gateloop_block = GateLoopBlock(dim, depth = coarse_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None + self.coarse_post_gateloop_block = GateLoopBlock(dim, depth = coarse_post_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None + # main autoregressive attention network # attending to a face token @@ -1560,8 +1563,9 @@ def forward_on_codes( coarse_cache, fine_cache, coarse_gateloop_cache, + coarse_post_gateloop_cache, fine_gateloop_cache - ) = cache if exists(cache) else ((None,) * 5) + ) = cache if exists(cache) else ((None,) * 6) if exists(cache): cached_face_codes_len = cached_attended_face_codes.shape[-2] @@ -1597,6 +1601,10 @@ def forward_on_codes( return_hiddens = True, **attn_context_kwargs ) + + if exists(self.coarse_post_gateloop_block): + face_codes, coarse_post_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_post_gateloop_cache) + else: attended_face_codes = None @@ -1717,6 +1725,7 @@ def forward_on_codes( coarse_cache, 
fine_cache, coarse_gateloop_cache, + coarse_post_gateloop_cache, fine_gateloop_cache ) diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py index 4e7c72a..77f1c8e 100644 --- a/meshgpt_pytorch/version.py +++ b/meshgpt_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.4.3' +__version__ = '1.5.0' diff --git a/tests/test_meshgpt.py b/tests/test_meshgpt.py index 1912410..bd5d077 100644 --- a/tests/test_meshgpt.py +++ b/tests/test_meshgpt.py @@ -38,6 +38,7 @@ def test_readme(adaptive_rmsnorm): fine_cross_attend_text = True, text_cond_with_film = False, condition_on_text = True, + coarse_post_gateloop_depth = 1, coarse_adaptive_rmsnorm = adaptive_rmsnorm ) From 36d8dbac9f16cbe4542b6c1681fbbffb874daab5 Mon Sep 17 00:00:00 2001 From: lucidrains Date: Thu, 20 Jun 2024 12:02:29 -0700 Subject: [PATCH 2/2] fix post coarse transformer gateloop layer name --- meshgpt_pytorch/meshgpt_pytorch.py | 2 +- meshgpt_pytorch/version.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/meshgpt_pytorch/meshgpt_pytorch.py b/meshgpt_pytorch/meshgpt_pytorch.py index 4902756..2963eba 100644 --- a/meshgpt_pytorch/meshgpt_pytorch.py +++ b/meshgpt_pytorch/meshgpt_pytorch.py @@ -1603,7 +1603,7 @@ def forward_on_codes( ) if exists(self.coarse_post_gateloop_block): - face_codes, coarse_post_gateloop_cache = self.coarse_gateloop_block(face_codes, cache = coarse_post_gateloop_cache) + face_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(face_codes, cache = coarse_post_gateloop_cache) else: attended_face_codes = None diff --git a/meshgpt_pytorch/version.py b/meshgpt_pytorch/version.py index 77f1c8e..51ed7c4 100644 --- a/meshgpt_pytorch/version.py +++ b/meshgpt_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.5.0' +__version__ = '1.5.1'