Skip to content

Commit

Permalink
edge values coming from HTR needs to be subjected to value residual a…
Browse files Browse the repository at this point in the history
…s well
  • Loading branch information
lucidrains committed Dec 20, 2024
1 parent ca83f39 commit 504c2bb
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
18 changes: 13 additions & 5 deletions gotennet_pytorch/gotennet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import einx
from einx import get_at

from einops import repeat, reduce
from einops import rearrange, repeat, reduce
from einops.layers.torch import Rearrange

from e3nn.o3 import spherical_harmonics
Expand Down Expand Up @@ -467,18 +467,25 @@ def forward(

# value residual mixing

next_value_residuals = (values, post_attn_values)
next_value_residuals = (values, post_attn_values, edge_values)

if exists(self.to_value_residual_mix):
assert exists(value_residuals)

value_residual, post_attn_values_residual = value_residuals
value_residual, post_attn_values_residual, edge_values_residual = value_residuals

mix = self.to_value_residual_mix(hi)

values = values.lerp(value_residual, mix)
post_attn_values = post_attn_values.lerp(post_attn_values_residual, mix)

if exists(neighbor_indices):
mix = get_at('b h [n] ..., b i j -> b h i j ...', mix, neighbor_indices)
else:
mix = rearrange(mix, 'b h j ... -> b h 1 j ...')

edge_values = edge_values.lerp(edge_values_residual, mix)

# account for neighbor logic

if exists(neighbor_indices):
Expand Down Expand Up @@ -592,7 +599,8 @@ def __init__(
ff_kwargs: dict = dict(),
return_coors = True,
proj_invariant_dim = None,
final_norm = True
final_norm = True,
add_value_residual = True
):
super().__init__()
self.accept_embed = accept_embed
Expand Down Expand Up @@ -631,7 +639,7 @@ def __init__(

self.layers.append(ModuleList([
HierarchicalTensorRefinement(dim, dim_edge_refinement, max_degree),
GeometryAwareTensorAttention(dim, max_degree, dim_head, heads, mlp_expansion_factor, learned_value_residual_mix = not is_first),
GeometryAwareTensorAttention(dim, max_degree, dim_head, heads, mlp_expansion_factor, learned_value_residual_mix = add_value_residual and not is_first),
EquivariantFeedForward(dim, max_degree, mlp_expansion_factor),
]))

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "gotennet-pytorch"
version = "0.1.1"
version = "0.1.2"
description = "GotenNet in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down

0 comments on commit 504c2bb

Please sign in to comment.