Commit 0bda114 (1 parent: 665ecea)
Showing 4 changed files with 296 additions and 4 deletions.
Project.toml
@@ -3,11 +3,17 @@
 uuid = "ae93bb3c-0630-4e85-811e-b7d91ad6ecbd"
 authors = ["Kenta Sato <[email protected]> and contributors"]
 version = "1.0.0-DEV"

+[deps]
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

 [compat]
-julia = "1.6"
+julia = "1.9"

 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"

 [targets]
-test = ["Test"]
+test = ["Test", "Rotations"]
README.md
@@ -1,3 +1,50 @@
# MessagePassingIPA

[![Build Status](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml?query=branch%3Amain)

This package provides an Invariant Point Attention (IPA) layer coupled with
graph-based message passing. It processes structured data by leveraging both
geometric and topological information for representation learning. All
operations support automatic differentiation and GPU acceleration.

For details, see [the AlphaFold2
paper](https://doi.org/10.1038/s41586-021-03819-2), in particular Algorithm 22
of the supplementary material:

Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
prediction with AlphaFold. Nature 596, 583–589 (2021).

# Usage

```julia
# Load packages
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, rigid_from_3points
using GraphNeuralNetworks: rand_graph

# Initialize an IPA layer
n_dims_s = 32  # the dimension of single representations
n_dims_z = 16  # the dimension of pair representations
ipa = InvariantPointAttention(n_dims_s, n_dims_z)

# Generate a random graph and node/edge features
n_nodes = 100
n_edges = 500
g = rand_graph(n_nodes, n_edges)
s = randn(Float32, n_dims_s, n_nodes)
z = randn(Float32, n_dims_z, n_edges)

# Generate random atom coordinates
p = randn(Float32, 3, n_nodes) * 100  # centroids
x1 = p .+ randn(Float32, 3, n_nodes)  # N atoms
x2 = p .+ randn(Float32, 3, n_nodes)  # CA atoms
x3 = p .+ randn(Float32, 3, n_nodes)  # C atoms
rigid = RigidTransformation(rigid_from_3points(x1, x2, x3)...)

# Apply the IPA layer
out = ipa(g, s, z, rigid)
@assert size(out) == (n_dims_s, n_nodes)
```
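The IPA output is invariant under global rigid motions of the input coordinates. A quick sanity check of this property, continuing from the snippet above (it mirrors the package tests; `RotMatrix` comes from Rotations.jl, which this package uses only as a test dependency):

```julia
using Rotations: RotMatrix

# Apply the same random rotation and translation to all three atom sets
R, t = rand(RotMatrix{3,Float32}), randn(Float32, 3)
rigid_moved = RigidTransformation(rigid_from_3points(R * x1 .+ t, R * x2 .+ t, R * x3 .+ t)...)

# The output should be unchanged up to floating-point error
@assert ipa(g, s, z, rigid) ≈ ipa(g, s, z, rigid_moved)
```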
src/MessagePassingIPA.jl
@@ -1,5 +1,204 @@
module MessagePassingIPA

-# Write your package code here.
using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus
using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors
using LinearAlgebra: normalize, ×

# Algorithm 21 (x1: N, x2: Ca, x3: C)
function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector)
    v1 = x3 - x2
    v2 = x1 - x2
    e1 = normalize(v1)
    u2 = v2 - e1 * (e1'v2)
    e2 = normalize(u2)
    e3 = e1 × e2
    R = [e1 e2 e3]
    t = x2
    return R, t
end

function rigid_from_3points(x1::AbstractMatrix, x2::AbstractMatrix, x3::AbstractMatrix)
    v1 = x3 .- x2
    v2 = x1 .- x2
    e1 = v1 ./ sqrt.(sum(abs2, v1, dims=1))
    u2 = v2 .- e1 .* sum(e1 .* v2, dims=1)
    e2 = u2 ./ sqrt.(sum(abs2, u2, dims=1))
    e3 = similar(e1)
    e3[1, :] = e1[2, :] .* e2[3, :] .- e1[3, :] .* e2[2, :]
    e3[2, :] = e1[3, :] .* e2[1, :] .- e1[1, :] .* e2[3, :]
    e3[3, :] = e1[1, :] .* e2[2, :] .- e1[2, :] .* e2[1, :]
    R = similar(e1, (3, 3, size(e1, 2)))
    R[:, 1, :] = e1
    R[:, 2, :] = e2
    R[:, 3, :] = e3
    t = x2
    return R, t
end
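A minimal standalone sketch of a sanity check for the matrix method above (only the package and LinearAlgebra assumed): each returned frame `R[:, :, j]` should be an orthonormal rotation matrix.

```julia
using LinearAlgebra: I
using MessagePassingIPA: rigid_from_3points

x1, x2, x3 = (randn(Float32, 3, 10) for _ in 1:3)  # arbitrary N, CA, C coordinates
R, t = rigid_from_3points(x1, x2, x3)
for j in axes(R, 3)
    # each frame should satisfy Rᵀ R ≈ I
    @assert isapprox(R[:, :, j]' * R[:, :, j], Matrix{Float32}(I, 3, 3); atol=1f-4)
end
```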
# Rigid transformation
# --------------------

# N: number of residues
struct RigidTransformation{T,A<:AbstractArray{T,3},B<:AbstractArray{T,2}}
    rotations::A    # (3, 3, N)
    translations::B # (3, N)
end

"""
    RigidTransformation(rotations, translations)

Create a sequence of rigid transformations.

# Arguments
- `rotations`: 3×3×N array; `rotations[:,:,j]` represents a single rotation
- `translations`: 3×N array; `translations[:,j]` represents a single translation
"""
RigidTransformation

Flux.@functor RigidTransformation

# x: (3, ?, N)
"""
    transform(rigid::RigidTransformation, x::AbstractArray)

Apply the transformation `rigid` to `x`.
"""
transform(rigid::RigidTransformation{T}, x::AbstractArray{T,3}) where {T} =
    batched_mul(rigid.rotations, x) .+ unsqueeze(rigid.translations, dims=2)

# y: (3, ?, N)
"""
    inverse_transform(rigid::RigidTransformation, y::AbstractArray)

Apply the inverse transformation of `rigid` to `y`.
"""
inverse_transform(rigid::RigidTransformation{T}, y::AbstractArray{T,3}) where {T} =
    batched_mul(
        batched_transpose(rigid.rotations),
        y .- unsqueeze(rigid.translations, dims=2),
    )
# Invariant point attention
# -------------------------

struct InvariantPointAttention
    # hyperparameters
    n_heads::Int
    c::Int
    n_query_points::Int
    n_point_values::Int

    # trainable layers and weights
    map_nodes::Dense
    map_points::Dense
    map_pairs::Dense
    map_final::Dense
    header_weights_raw::Any
end

Flux.@functor InvariantPointAttention

"""
    InvariantPointAttention(
        n_dims_s, n_dims_z;
        n_heads = 12,
        c = 16,
        n_query_points = 4,
        n_point_values = 8)

Create an invariant point attention layer.
"""
function InvariantPointAttention(
    n_dims_s::Integer,
    n_dims_z::Integer;
    n_heads::Integer=12,
    c::Integer=16,
    n_query_points::Integer=4,
    n_point_values::Integer=8
)
    # initialize layer weights so that outputs have std = 1 (as assumed in
    # AlphaFold2) if inputs follow the standard normal distribution
    init = Flux.kaiming_uniform(gain=1.0)
    map_nodes = Dense(n_dims_s => n_heads * c * 3, bias=false; init)
    map_points = Dense(
        n_dims_s => n_heads * (n_query_points * 2 + n_point_values) * 3,
        bias=false;
        init
    )
    map_pairs = Dense(n_dims_z => n_heads, bias=false; init)
    map_final =
        Dense(n_heads * (n_dims_z + c + n_point_values * (3 + 1)) => n_dims_s, bias=true)
    header_weights_raw = @. log(expm1($(ones(Float32, n_heads)))) # initialized so that initial weights are ones
    return InvariantPointAttention(
        n_heads,
        c,
        n_query_points,
        n_point_values,
        map_nodes,
        map_points,
        map_pairs,
        map_final,
        header_weights_raw,
    )
end
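A minimal standalone sketch of why this initialization makes the per-head gates start at one: `softplus` inverts `x -> log(expm1(x))`, so `softplus.(header_weights_raw)` is a vector of ones at initialization.

```julia
using Flux: softplus

n_heads = 12
raw = @. log(expm1($(ones(Float32, n_heads))))  # same expression as in the constructor above
γ = softplus.(raw)   # softplus(log(expm1(1))) == log(1 + (e - 1)) == 1
@assert all(isapprox.(γ, 1f0))
```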
# Algorithm 22
function (ipa::InvariantPointAttention)(
    g::GNNGraph,
    s::AbstractMatrix,
    z::AbstractMatrix,
    rigid::RigidTransformation,
)
    F = eltype(s)
    n_residues = size(s, 2)
    (; n_heads, c, n_query_points, n_point_values) = ipa

    # map inputs (residues come at the last dimension)
    nodes = reshape(ipa.map_nodes(s), n_heads, :, n_residues)
    points = transform(rigid, reshape(ipa.map_points(s), 3, :, n_residues))
    bias = ipa.map_pairs(z)

    # split into queries, keys and values
    nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2)
    points_q, points_k, points_v = chunk(
        points,
        size=n_heads * [n_query_points, n_query_points, n_point_values],
        dims=2,
    )
    points_q = reshape(points_q, 3, n_heads, :, n_residues)
    points_k = reshape(points_k, 3, n_heads, :, n_residues)
    points_v = reshape(points_v, 3, n_heads, :, n_residues)

    # run message passing
    w_C = F(√(2 / 9n_query_points))
    w_L = F(1 / √3)
    γ = softplus.(ipa.header_weights_raw)
    function message(xi, xj, e)
        u = sumdrop(xi.nodes_q .* xj.nodes_k, dims=2)  # inner products
        v = sumdrop(abs2.(xi.points_q .- xj.points_k), dims=(1, 3))  # sum of squared distances
        attn_logits = @. w_L * (1 / √$(F(c)) * u + e - γ * w_C / 2 * v)  # logits of attention scores
        return (; attn_logits, nodes_v=xj.nodes_v, points_v=xj.points_v)
    end
    xi = xj = (; nodes_q, nodes_k, nodes_v, points_q, points_k, points_v)
    e = bias
    msgs = apply_edges(message, g; xi, xj, e)

    # aggregate messages from neighbors
    attn = softmax_edge_neighbors(g, msgs.attn_logits)  # (heads, edges)
    out_pairs =
        aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* unsqueeze(z, dims=1))
    out_nodes = aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* msgs.nodes_v)
    out_points = aggregate_neighbors(g, +, reshape(attn, 1, n_heads, 1, :) .* msgs.points_v)
    out_points = inverse_transform(rigid, reshape(out_points, 3, :, n_residues))
    out_points_norm = sqrt.(sumdrop(abs2.(out_points), dims=1))

    # return the final output
    out = vcat(flatten.((out_pairs, out_nodes, out_points, out_points_norm))...)
    return ipa.map_final(out)
end
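For reference, the per-edge, per-head attention logit assembled in `message` above can be written as follows (cf. Algorithm 22 of the AlphaFold2 supplement; the symbols mirror the variables in the code):

```math
a_{ij}^{(h)} = w_L \left( \frac{1}{\sqrt{c}}\, u_{ij}^{(h)} + b_{ij}^{(h)} - \frac{\gamma^{(h)} w_C}{2}\, v_{ij}^{(h)} \right),
\qquad w_C = \sqrt{\frac{2}{9\,N_{\text{query}}}}, \quad w_L = \frac{1}{\sqrt{3}}
```

Here $u_{ij}^{(h)}$ is the inner product of node queries and keys (`u`), $b_{ij}^{(h)}$ is the pair bias from `map_pairs(z)` (`e`), $v_{ij}^{(h)}$ is the summed squared distance between query and key points in the global frame (`v`), $\gamma^{(h)}$ is the learned per-head gate `softplus(header_weights_raw[h])`, and $N_{\text{query}}$ is `n_query_points`.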
sumdrop(x; dims) = dropdims(sum(x; dims); dims)

end
test/runtests.jl
@@ -1,6 +1,46 @@
using MessagePassingIPA
using GraphNeuralNetworks: rand_graph
using Rotations: RotMatrix
using Test

@testset "MessagePassingIPA.jl" begin
-    # Write your tests here.
    @testset "RigidTransformation" begin
        n = 100
        rotations = stack(rand(RotMatrix{3,Float32}) for _ in 1:n)
        translations = randn(Float32, 3, n)
        rigid = MessagePassingIPA.RigidTransformation(rotations, translations)
        x = randn(Float32, 3, 12, n)
        y = MessagePassingIPA.transform(rigid, x)
        @test size(x) == size(y)
        @test x ≈ MessagePassingIPA.inverse_transform(rigid, y)
    end

    @testset "InvariantPointAttention" begin
        n_dims_s = 32
        n_dims_z = 16
        ipa = MessagePassingIPA.InvariantPointAttention(n_dims_s, n_dims_z)

        n_nodes = 100
        n_edges = 500
        g = rand_graph(n_nodes, n_edges)
        s = randn(Float32, n_dims_s, n_nodes)
        z = randn(Float32, n_dims_z, n_edges)

        # check returned type and size
        c = randn(Float32, 3, n_nodes) * 1000
        x1 = c .+ randn(Float32, 3, n_nodes)
        x2 = c .+ randn(Float32, 3, n_nodes)
        x3 = c .+ randn(Float32, 3, n_nodes)
        rigid1 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
        @test ipa(g, s, z, rigid1) isa Matrix{Float32}
        @test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes)

        # check invariance
        R, t = rand(RotMatrix{3,Float32}), randn(Float32, 3)
        x1 = R * x1 .+ t
        x2 = R * x2 .+ t
        x3 = R * x3 .+ t
        rigid2 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
        @test ipa(g, s, z, rigid1) ≈ ipa(g, s, z, rigid2)
    end
end