Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Nov 4, 2023
1 parent 665ecea commit 0bda114
Show file tree
Hide file tree
Showing 4 changed files with 296 additions and 4 deletions.
10 changes: 8 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ uuid = "ae93bb3c-0630-4e85-811e-b7d91ad6ecbd"
authors = ["Kenta Sato <[email protected]> and contributors"]
version = "1.0.0-DEV"

[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
julia = "1.6"
julia = "1.9"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"

[targets]
test = ["Test"]
test = ["Test", "Rotations"]
47 changes: 47 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,50 @@
# MessagePassingIPA

[![Build Status](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml?query=branch%3Amain)


This package introduces an Invariant Point Attention (IPA) layer coupled with
graph-based message passing, tailored to process structured data by effectively
leveraging both geometric and topological information for superior
representation learning. The operations within this package are designed to
support automatic differentiation and GPU acceleration, ensuring optimal
performance.

For a deeper understanding, you may refer to [the AlphaFold2
paper](https://doi.org/10.1038/s41586-021-03819-2), particularly Algorithm 22 in
the supplementary material.

Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
prediction with AlphaFold. Nature 596, 583–589 (2021).


# Usage

```julia
# Load packages
using MessagePassingIPA: RigidTransformation, InvariantPointAttention, rigid_from_3points
using GraphNeuralNetworks: rand_graph

# Initialize an IPA layer
n_dims_s = 32 # the dimension of single representations
n_dims_z = 16 # the diemnsion of pair representations
ipa = InvariantPointAttention(n_dims_s, n_dims_z)

# Generate a random graph and node/edge features
n_nodes = 100
n_edges = 500
g = rand_graph(n_nodes, n_edges)
s = randn(Float32, n_dims_s, n_nodes)
z = randn(Float32, n_dims_z, n_edges)

# Generate random atom coordinates
p = randn(Float32, 3, n_nodes) * 100 # centroid
x1 = p .+ randn(Float32, 3, n_nodes) # N atoms
x2 = p .+ randn(Float32, 3, n_nodes) # CA atoms
x3 = p .+ randn(Float32, 3, n_nodes) # C atoms
rigid = RigidTransformation(rigid_from_3points(x1, x2, x3)...)

# Apply the IPA layer
out = ipa(g, s, z, rigid)
@assert size(out) == (n_dims_s, n_nodes)
```
201 changes: 200 additions & 1 deletion src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,204 @@
module MessagePassingIPA

# Write your package code here.
using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus
using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors
using LinearAlgebra: normalize

# Algorithm 21 (x1: N, x2: Ca, x3: C)
function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector)
v1 = x3 - x2
v2 = x1 - x2
e1 = normalize(v1)
u2 = v2 - e1 * (e1'v2)
e2 = normalize(u2)
e3 = e1 × e2
R = [e1 e2 e3]
t = x2
return R, t
end

function rigid_from_3points(x1::AbstractMatrix, x2::AbstractMatrix, x3::AbstractMatrix)
v1 = x3 .- x2
v2 = x1 .- x2
e1 = v1 ./ sqrt.(sum(abs2, v1, dims=1))
u2 = v2 .- e1 .* sum(e1 .* v2, dims=1)
e2 = u2 ./ sqrt.(sum(abs2, u2, dims=1))
e3 = similar(e1)
e3[1, :] = e1[2, :] .* e2[3, :] .- e1[3, :] .* e2[2, :]
e3[2, :] = e1[3, :] .* e2[1, :] .- e1[1, :] .* e2[3, :]
e3[3, :] = e1[1, :] .* e2[2, :] .- e1[2, :] .* e2[1, :]
R = similar(e1, (3, 3, size(e1, 2)))
R[:, 1, :] = e1
R[:, 2, :] = e2
R[:, 3, :] = e3
t = x2
return R, t
end


# Rigid transformation
# --------------------

# N: number of residues
struct RigidTransformation{T,A<:AbstractArray{T,3},B<:AbstractArray{T,2}}
rotations::A # (3, 3, N)
translations::B # (3, N)
end

"""
RigidTransformation(rotations, translations)
Create a sequence of rigid transformations.
# Arguments
- `rotations`: 3×3xN array, `rotations[:,:,j]` represents a single rotation
- `translations`: 3×N array, `translations[:,j]` represents a single translation
"""
RigidTransformation

Flux.@functor RigidTransformation

# x: (3, ?, N)
"""
transform(rigid::RigidTransformation, x::AbstractArray)
Apply transformation `rigid` to `x`.
"""
transform(rigid::RigidTransformation{T}, x::AbstractArray{T,3}) where {T} =
batched_mul(rigid.rotations, x) .+ unsqueeze(rigid.translations, dims=2)

# y: (3, ?, N)
"""
inverse_transform(rigid::RigidTransformation, y::AbstractArray)
Apply inverse transformation `rigid` to `y`.
"""
inverse_transform(rigid::RigidTransformation{T}, y::AbstractArray{T,3}) where {T} =
batched_mul(
batched_transpose(rigid.rotations),
y .- unsqueeze(rigid.translations, dims=2),
)


# Invariant point attention
# -------------------------

struct InvariantPointAttention
# hyperparameters
n_heads::Int
c::Int
n_query_points::Int
n_point_values::Int

# trainable layers and weights
map_nodes::Dense
map_points::Dense
map_pairs::Dense
map_final::Dense
header_weights_raw::Any
end

Flux.@functor InvariantPointAttention

"""
InvariantPointAttention(
n_dims_s, n_dims_z;
n_heads = 12,
c = 16,
n_query_points = 4,
n_point_values = 8)
Create an invariant point attention layer.
"""
function InvariantPointAttention(
n_dims_s::Integer,
n_dims_z::Integer;
n_heads::Integer=12,
c::Integer=16,
n_query_points::Integer=4,
n_point_values::Integer=8
)
# initialize layer weights so that outputs have std = 1 (as assumed in
# AlphaFold2) if inputs follow the standard normal distribution
init = Flux.kaiming_uniform(gain=1.0)
map_nodes = Dense(n_dims_s => n_heads * c * 3, bias=false; init)
map_points = Dense(
n_dims_s => n_heads * (n_query_points * 2 + n_point_values) * 3,
bias=false;
init
)
map_pairs = Dense(n_dims_z => n_heads, bias=false; init)
map_final =
Dense(n_heads * (n_dims_z + c + n_point_values * (3 + 1)) => n_dims_s, bias=true)
header_weights_raw = @. log(expm1($(ones(Float32, n_heads)))) # initialized so that initial weights are ones
return InvariantPointAttention(
n_heads,
c,
n_query_points,
n_point_values,
map_nodes,
map_points,
map_pairs,
map_final,
header_weights_raw,
)
end

# Algorithm 22
function (ipa::InvariantPointAttention)(
g::GNNGraph,
s::AbstractMatrix,
z::AbstractMatrix,
rigid::RigidTransformation,
)
F = eltype(s)
n_residues = size(s, 2)
(; n_heads, c, n_query_points, n_point_values) = ipa

# map inputs (residues come at the last dimension)
nodes = reshape(ipa.map_nodes(s), n_heads, :, n_residues)
points = transform(rigid, reshape(ipa.map_points(s), 3, :, n_residues))
bias = ipa.map_pairs(z)

# split into queries, keys and values
nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2)
points_q, points_k, points_v = chunk(
points,
size=n_heads * [n_query_points, n_query_points, n_point_values],
dims=2,
)
points_q = reshape(points_q, 3, n_heads, :, n_residues)
points_k = reshape(points_k, 3, n_heads, :, n_residues)
points_v = reshape(points_v, 3, n_heads, :, n_residues)

# run message passing
w_C = F((2 / 9n_query_points))
w_L = F(1 / 3)
γ = softplus.(ipa.header_weights_raw)
function message(xi, xj, e)
u = sumdrop(xi.nodes_q .* xj.nodes_k, dims=2) # inner products
v = sumdrop(abs2.(xi.points_q .- xj.points_k), dims=(1, 3)) # sum of squared distances
attn_logits = @. w_L * (1 / $(F(c)) * u + e - γ * w_C / 2 * v) # logits of attention scores
return (; attn_logits, nodes_v=xj.nodes_v, points_v=xj.points_v)
end
xi = xj = (; nodes_q, nodes_k, nodes_v, points_q, points_k, points_v)
e = bias
msgs = apply_edges(message, g; xi, xj, e)

# aggregate messages from neighbors
attn = softmax_edge_neighbors(g, msgs.attn_logits) # (heads, edges)
out_pairs =
aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* unsqueeze(z, dims=1))
out_nodes = aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* msgs.nodes_v)
out_points = aggregate_neighbors(g, +, reshape(attn, 1, n_heads, 1, :) .* msgs.points_v)
out_points = inverse_transform(rigid, reshape(out_points, 3, :, n_residues))
out_points_norm = sqrt.(sumdrop(abs2.(out_points), dims=1))

# return the final output
out = vcat(flatten.((out_pairs, out_nodes, out_points, out_points_norm))...)
return ipa.map_final(out)
end

sumdrop(x; dims) = dropdims(sum(x; dims); dims)

end
42 changes: 41 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
using MessagePassingIPA
using GraphNeuralNetworks: rand_graph
using Rotations: RotMatrix
using Test

@testset "MessagePassingIPA.jl" begin
# Write your tests here.
@testset "RigidTransformation" begin
n = 100
rotations = stack(rand(RotMatrix{3,Float32}) for _ in 1:n)
translations = randn(Float32, 3, n)
rigid = MessagePassingIPA.RigidTransformation(rotations, translations)
x = randn(Float32, 3, 12, n)
y = MessagePassingIPA.transform(rigid, x)
@test size(x) == size(y)
@test x MessagePassingIPA.inverse_transform(rigid, y)
end

@testset "InvariantPointAttention" begin
n_dims_s = 32
n_dims_z = 16
ipa = MessagePassingIPA.InvariantPointAttention(n_dims_s, n_dims_z)

n_nodes = 100
n_edges = 500
g = rand_graph(n_nodes, n_edges)
s = randn(Float32, n_dims_s, n_nodes)
z = randn(Float32, n_dims_z, n_edges)

# check returned type and size
c = randn(Float32, 3, n_nodes) * 1000
x1 = c .+ randn(Float32, 3, n_nodes)
x2 = c .+ randn(Float32, 3, n_nodes)
x3 = c .+ randn(Float32, 3, n_nodes)
rigid1 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) isa Matrix{Float32}
@test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes)

# check invariance
R, t = rand(RotMatrix{3,Float32}), randn(Float32, 3)
x1 = R * x1 .+ t
x2 = R * x2 .+ t
x3 = R * x3 .+ t
rigid2 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
@test ipa(g, s, z, rigid1) ipa(g, s, z, rigid2)
end
end

0 comments on commit 0bda114

Please sign in to comment.