Skip to content

Commit

Permalink
Merge branch 'lucidrains:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcusLoppe authored Jun 20, 2024
2 parents 14126a2 + 36d8dba commit 14a4470
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 2 deletions.
11 changes: 10 additions & 1 deletion meshgpt_pytorch/meshgpt_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,7 @@ def __init__(
cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out text condition
dropout = 0.,
coarse_pre_gateloop_depth = 2,
coarse_post_gateloop_depth = 0,
coarse_adaptive_rmsnorm = False,
fine_pre_gateloop_depth = 2,
gateloop_use_heinsen = False,
Expand Down Expand Up @@ -1179,6 +1180,8 @@ def __init__(

self.coarse_gateloop_block = GateLoopBlock(dim, depth = coarse_pre_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None

self.coarse_post_gateloop_block = GateLoopBlock(dim, depth = coarse_post_gateloop_depth, use_heinsen = gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None

# main autoregressive attention network
# attending to a face token

Expand Down Expand Up @@ -1560,8 +1563,9 @@ def forward_on_codes(
coarse_cache,
fine_cache,
coarse_gateloop_cache,
coarse_post_gateloop_cache,
fine_gateloop_cache
) = cache if exists(cache) else ((None,) * 5)
) = cache if exists(cache) else ((None,) * 6)

if exists(cache):
cached_face_codes_len = cached_attended_face_codes.shape[-2]
Expand Down Expand Up @@ -1597,6 +1601,10 @@ def forward_on_codes(
return_hiddens = True,
**attn_context_kwargs
)

if exists(self.coarse_post_gateloop_block):
face_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(face_codes, cache = coarse_post_gateloop_cache)

else:
attended_face_codes = None

Expand Down Expand Up @@ -1717,6 +1725,7 @@ def forward_on_codes(
coarse_cache,
fine_cache,
coarse_gateloop_cache,
coarse_post_gateloop_cache,
fine_gateloop_cache
)

Expand Down
2 changes: 1 addition & 1 deletion meshgpt_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.4.3'
__version__ = '1.5.1'
1 change: 1 addition & 0 deletions tests/test_meshgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def test_readme(adaptive_rmsnorm):
fine_cross_attend_text = True,
text_cond_with_film = False,
condition_on_text = True,
coarse_post_gateloop_depth = 1,
coarse_adaptive_rmsnorm = adaptive_rmsnorm
)

Expand Down

0 comments on commit 14a4470

Please sign in to comment.