
add hyper connections
lucidrains committed Dec 29, 2024
1 parent ece5da2 commit 7657954
Showing 4 changed files with 62 additions and 17 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -57,3 +57,14 @@ invariant, coors_out = model(atom_ids, adj_mat = adj_mat, coors = coors, lens =
    url     = {https://api.semanticscholar.org/CorpusID:273532030}
}
```

+```bibtex
+@article{Zhu2024HyperConnections,
+    title   = {Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2409.19606},
+    url     = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
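For background: hyper-connections (Zhu et al., cited above) widen the transformer's single residual stream into several parallel streams, with learned weights deciding what each block reads from the streams and writes back into them. Below is a minimal, illustrative sketch of the static variant; the class name, shapes, and initialization are assumptions for exposition, not the API of the `hyper-connections` package this commit uses.

```python
import torch
from torch import nn

class StaticHyperConnection(nn.Module):
    """Toy static hyper-connection over n parallel residual streams."""

    def __init__(self, num_streams):
        super().__init__()
        # weights for reading the block's input off the streams
        self.read_weights = nn.Parameter(torch.full((num_streams,), 1. / num_streams))
        # weights for writing the block's output back into each stream
        self.write_weights = nn.Parameter(torch.ones(num_streams))
        # learned mixing among the streams themselves (identity = plain residual)
        self.stream_mix = nn.Parameter(torch.eye(num_streams))

    def forward(self, streams):  # streams: (num_streams, batch, seq, dim)
        # the block input is a weighted combination of all streams
        block_input = torch.einsum('s, s b n d -> b n d', self.read_weights, streams)

        def add_residual(block_output):
            # mix the streams, then write the block output into each of them
            mixed = torch.einsum('t s, s b n d -> t b n d', self.stream_mix, streams)
            return mixed + torch.einsum('s, b n d -> s b n d', self.write_weights, block_output)

        return block_input, add_residual

# usage: read the block input, run the block, write its output back
streams = torch.randn(4, 2, 16, 64)
hc = StaticHyperConnection(num_streams = 4)
block_input, add_residual = hc(streams)
streams = add_residual(block_input)  # stand-in for a real block's output
```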
54 changes: 41 additions & 13 deletions gotennet_pytorch/gotennet.py
@@ -18,6 +18,8 @@

from gotennet_pytorch.tensor_typing import Float, Int, Bool

+from hyper_connections import get_init_and_expand_reduce_stream_functions

# ein notation

# b - batch
@@ -265,12 +267,9 @@ def forward(

            x_residuals.append(modulated_one_degree)

-        # handle residuals within the module
-
-        h = h + h_residual
-        x = [*map(sum, zip(x, x_residuals))]
+        # return residuals

-        return h, x
+        return h_residual, x_residuals
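This refactor sets up the hyper connections added below: submodules no longer apply their own residual connections but return `(h_residual, x_residuals)`, letting the trunk route the invariant `h` residual through a hyper connection while the equivariant features `x` keep a plain summed residual.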

# hierarchical tensor refinement
# section 3.4
@@ -567,8 +566,8 @@ def forward(

        # modulate with invariant scales and sum residuals

-        h_with_residual = h + reduce(h_scales, 'b i j 1 d -> b i d', 'sum')
-        x_with_residual = []
+        h_residual = reduce(h_scales, 'b i j 1 d -> b i d', 'sum')
+        x_residuals = []

        for one_degree, one_r_ij, one_degree_scale, one_r_ij_scale in zip(x, r_ij, x_scales.unbind(dim = -2), r_ij_scales.unbind(dim = -2)):

@@ -582,9 +581,9 @@ def forward(
            else:
                x_ij_residual = einsum('b j d m, b i j d -> b i d m', one_degree, one_degree_scale)

-            x_with_residual.append(r_ij_residual + x_ij_residual)
+            x_residuals.append(r_ij_residual + x_ij_residual)

-        out = (h_with_residual, x_with_residual)
+        out = (h_residual, x_residuals)

        if not return_value_residual:
            return out
@@ -612,7 +611,8 @@ def __init__(
        return_coors = True,
        proj_invariant_dim = None,
        final_norm = True,
-        add_value_residual = True
+        add_value_residual = True,
+        num_residual_streams = 4
    ):
        super().__init__()
        self.accept_embed = accept_embed
@@ -622,6 +622,10 @@

        dim_edge_refinement = default(dim_edge_refinement, dim)

+        # hyper connections, applied to invariant h for starters
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

        # only consider neighbors less than `cutoff_radius`, in paper, they used ~ 5 angstroms
        # can further randomly select from eligible neighbors with `max_neighbors`

@@ -645,6 +649,7 @@
        # layers, thus deep learning

        self.layers = ModuleList([])
+        self.residual_fns = ModuleList([])

        for layer_index in range(depth):
            is_first = layer_index == 0
@@ -655,6 +660,11 @@
                EquivariantFeedForward(dim, max_degree, mlp_expansion_factor),
            ]))

+            self.residual_fns.append(ModuleList([
+                init_hyper_conn(dim = dim),
+                init_hyper_conn(dim = dim),
+            ]))

        # not mentioned in paper, but transformers need a final norm

        self.final_norm = final_norm
@@ -754,25 +764,43 @@

        value_residuals = None

+        # maybe expand invariant h residual stream
+
+        h = self.expand_streams(h)

        # go through the layers

-        for htr, attn, ff in self.layers:
+        for (htr, attn, ff), (h_attn_residual_fn, h_ff_residual_fn) in zip(self.layers, self.residual_fns):

            # hierarchical tensor refinement

            t_ij = htr(t_ij, x, neighbor_indices = neighbor_indices)

            # followed by attention, but of course

-            (h, x), next_value_residuals = attn(h, t_ij, r_ij, x, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask, value_residuals = value_residuals, return_value_residual = True)
+            h, add_attn_residual = h_attn_residual_fn(h)

+            (h_residual, x_residuals), next_value_residuals = attn(h, t_ij, r_ij, x, mask = mask, neighbor_indices = neighbor_indices, neighbor_mask = neighbor_mask, value_residuals = value_residuals, return_value_residual = True)

+            # add attention residuals
+
+            h = add_attn_residual(h_residual)
+            x = [*map(sum, zip(x, x_residuals))]

            # handle value residual

            value_residuals = default(value_residuals, next_value_residuals)

            # feedforward

-            h, x = ff(h, x)
+            h, add_ff_residual = h_ff_residual_fn(h)

+            h_residual, x_residuals = ff(h, x)

+            # add feedforward residuals
+
+            h = add_ff_residual(h_residual)
+            x = [*map(sum, zip(x, x_residuals))]

        # maybe final norms

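Putting the pieces together, the wiring in `forward` follows the pattern below. This is a minimal sketch against the `hyper-connections` package as exercised by this diff; the toy `block` and tensor shapes are stand-ins, and the final `reduce_streams` call is inferred from the `self.reduce_streams` handle created in `__init__` (the hunk above is truncated before it appears).

```python
import torch
from torch import nn
from hyper_connections import get_init_and_expand_reduce_stream_functions

dim, num_residual_streams = 64, 4

# with a single stream, disable = True makes every call below a plain residual / no-op
init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

block = nn.Linear(dim, dim)           # toy stand-in for attention or the feedforward
residual_fn = init_hyper_conn(dim = dim)

h = torch.randn(2, 16, dim)           # toy invariant features

h = expand_streams(h)                 # widen h into num_residual_streams streams
h, add_residual = residual_fn(h)      # read the block's input off the streams
h = add_residual(block(h))            # write the block's output back as a residual
h = reduce_streams(h)                 # collapse back to a single stream
```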
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "gotennet-pytorch"
version = "0.1.4"
version = "0.2.0"
description = "GotenNet in Pytorch"
authors = [
{ name = "Phil Wang", email = "[email protected]" }
@@ -29,6 +29,7 @@ dependencies = [
    'einx>=0.3.0',
    'einops>=0.8.0',
    'jaxtyping',
+    'hyper-connections>=0.1.0',
    'torch>=2.4',
]

11 changes: 8 additions & 3 deletions tests/test_gotennet.py
@@ -1,3 +1,4 @@
+import pytest
import torch
from torch import sin, cos, stack
from einops import rearrange
@@ -67,7 +68,8 @@ def test_invariant():
    assert torch.allclose(inv1, inv2, atol = 1e-5)

@torch_default_dtype(torch.float64)
-def test_equivariant():
+@pytest.mark.parametrize('num_residual_streams', (1, 4))
+def test_equivariant(num_residual_streams):

    model = GotenNet(
        dim = 256,
@@ -79,7 +81,8 @@ def test_equivariant():
        return_coors = True,
        ff_kwargs = dict(
            layernorm_input = True
-        )
+        ),
+        num_residual_streams = num_residual_streams
    )

    random_rotation = rot(*torch.randn(3))
@@ -121,7 +124,8 @@ def test_equivariant_with_atom_feats():
    assert torch.allclose(coors1 @ random_rotation, coors2, atol = 1e-5)

@torch_default_dtype(torch.float64)
-def test_equivariant_neighbors():
+@pytest.mark.parametrize('num_residual_streams', (1, 4))
+def test_equivariant_neighbors(num_residual_streams):

    model = GotenNet(
        dim = 256,
@@ -132,6 +136,7 @@ def test_equivariant_neighbors():
        cutoff_radius = 5.,
        dim_edge_refinement = 256,
        return_coors = True,
+        num_residual_streams = num_residual_streams,
        ff_kwargs = dict(
            layernorm_input = True
        )
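Parametrizing `num_residual_streams` over `(1, 4)` means the equivariance tests exercise both the disabled single-stream path (plain residuals) and the multi-stream hyper-connection path.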
