From 0bda114fbccfc2f28a5b6904167305d296db3d30 Mon Sep 17 00:00:00 2001
From: Kenta Sato
Date: Sat, 4 Nov 2023 13:52:51 +0100
Subject: [PATCH] initial commit

---
 Project.toml             |  10 +-
 README.md                |  47 +++++++++
 src/MessagePassingIPA.jl | 201 ++++++++++++++++++++++++++++++++++++++-
 test/runtests.jl         |  42 +++++++-
 4 files changed, 296 insertions(+), 4 deletions(-)

diff --git a/Project.toml b/Project.toml
index 90bc4aa..4737ee5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -3,11 +3,17 @@ uuid = "ae93bb3c-0630-4e85-811e-b7d91ad6ecbd"
 authors = ["Kenta Sato and contributors"]
 version = "1.0.0-DEV"
 
+[deps]
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
 [compat]
-julia = "1.6"
+julia = "1.9"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+Rotations = "6038ab10-8711-5258-84ad-4b1120ba62dc"
 
 [targets]
-test = ["Test"]
+test = ["Test", "Rotations"]
diff --git a/README.md b/README.md
index 49e2080..5b6d0c4 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,50 @@
 # MessagePassingIPA
 
 [![Build Status](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/bicycle1885/MessagePassingIPA.jl/actions/workflows/CI.yml?query=branch%3Amain)
+
+This package provides an Invariant Point Attention (IPA) layer coupled with
+graph-based message passing. It processes structured data such as protein
+backbones by combining geometric information (per-residue rigid frames) with
+topological information (the neighborhood graph). All operations support
+automatic differentiation and GPU acceleration.
+
+For details, refer to [the AlphaFold2
+paper](https://doi.org/10.1038/s41586-021-03819-2), particularly Algorithm 22
+in the supplementary material:
+
+Jumper, J., Evans, R., Pritzel, A. et al. Highly accurate protein structure
+prediction with AlphaFold. Nature 596, 583–589 (2021).
+
+## Usage
+
+```julia
+# Load packages
+using MessagePassingIPA: RigidTransformation, InvariantPointAttention, rigid_from_3points
+using GraphNeuralNetworks: rand_graph
+
+# Initialize an IPA layer
+n_dims_s = 32  # the dimension of single representations
+n_dims_z = 16  # the dimension of pair representations
+ipa = InvariantPointAttention(n_dims_s, n_dims_z)
+
+# Generate a random graph and node/edge features
+n_nodes = 100
+n_edges = 500
+g = rand_graph(n_nodes, n_edges)
+s = randn(Float32, n_dims_s, n_nodes)
+z = randn(Float32, n_dims_z, n_edges)
+
+# Generate random atom coordinates around per-residue centroids
+p = randn(Float32, 3, n_nodes) * 100  # centroids
+x1 = p .+ randn(Float32, 3, n_nodes)  # N atoms
+x2 = p .+ randn(Float32, 3, n_nodes)  # CA atoms
+x3 = p .+ randn(Float32, 3, n_nodes)  # C atoms
+rigid = RigidTransformation(rigid_from_3points(x1, x2, x3)...)
+
+# Apply the IPA layer
+out = ipa(g, s, z, rigid)
+@assert size(out) == (n_dims_s, n_nodes)
+```
diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl
index 12eaaf2..43cc629 100644
--- a/src/MessagePassingIPA.jl
+++ b/src/MessagePassingIPA.jl
@@ -1,5 +1,204 @@
 module MessagePassingIPA
-# Write your package code here.
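+# Invariant Point Attention (Jumper et al. 2021, supplementary Algorithm 22)
+# implemented as message passing over a graph: nodes are residues carrying
+# rigid frames and single representations; edges carry pair representations.
+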
+using Flux: Flux, Dense, flatten, unsqueeze, chunk, batched_mul, batched_vec, batched_transpose, softplus
+using GraphNeuralNetworks: GNNGraph, apply_edges, softmax_edge_neighbors, aggregate_neighbors
+using LinearAlgebra: normalize
+
+# Algorithm 21 (x1: N, x2: Ca, x3: C)
+function rigid_from_3points(x1::AbstractVector, x2::AbstractVector, x3::AbstractVector)
+    v1 = x3 - x2
+    v2 = x1 - x2
+    e1 = normalize(v1)
+    u2 = v2 - e1 * (e1'v2)
+    e2 = normalize(u2)
+    e3 = e1 × e2
+    R = [e1 e2 e3]
+    t = x2
+    return R, t
+end
+
+function rigid_from_3points(x1::AbstractMatrix, x2::AbstractMatrix, x3::AbstractMatrix)
+    v1 = x3 .- x2
+    v2 = x1 .- x2
+    e1 = v1 ./ sqrt.(sum(abs2, v1, dims=1))
+    u2 = v2 .- e1 .* sum(e1 .* v2, dims=1)
+    e2 = u2 ./ sqrt.(sum(abs2, u2, dims=1))
+    e3 = similar(e1)
+    e3[1, :] = e1[2, :] .* e2[3, :] .- e1[3, :] .* e2[2, :]
+    e3[2, :] = e1[3, :] .* e2[1, :] .- e1[1, :] .* e2[3, :]
+    e3[3, :] = e1[1, :] .* e2[2, :] .- e1[2, :] .* e2[1, :]
+    R = similar(e1, (3, 3, size(e1, 2)))
+    R[:, 1, :] = e1
+    R[:, 2, :] = e2
+    R[:, 3, :] = e3
+    t = x2
+    return R, t
+end
+
+
+# Rigid transformation
+# --------------------
+
+# N: number of residues
+struct RigidTransformation{T,A<:AbstractArray{T,3},B<:AbstractArray{T,2}}
+    rotations::A    # (3, 3, N)
+    translations::B # (3, N)
+end
+
+"""
+    RigidTransformation(rotations, translations)
+
+Create a sequence of rigid transformations.
+
+# Arguments
+- `rotations`: 3×3×N array, `rotations[:,:,j]` represents a single rotation
+- `translations`: 3×N array, `translations[:,j]` represents a single translation
+"""
+RigidTransformation
+
+Flux.@functor RigidTransformation
+
+# x: (3, ?, N)
+"""
+    transform(rigid::RigidTransformation, x::AbstractArray)
+
+Apply transformation `rigid` to `x`.
+"""
+transform(rigid::RigidTransformation{T}, x::AbstractArray{T,3}) where {T} =
+    batched_mul(rigid.rotations, x) .+ unsqueeze(rigid.translations, dims=2)
+
+# y: (3, ?, N)
+"""
+    inverse_transform(rigid::RigidTransformation, y::AbstractArray)
+
+Apply inverse transformation `rigid` to `y`.
+"""
+inverse_transform(rigid::RigidTransformation{T}, y::AbstractArray{T,3}) where {T} =
+    batched_mul(
+        batched_transpose(rigid.rotations),
+        y .- unsqueeze(rigid.translations, dims=2),
+    )
+
+
+# Invariant point attention
+# -------------------------
+
+struct InvariantPointAttention
+    # hyperparameters
+    n_heads::Int
+    c::Int
+    n_query_points::Int
+    n_point_values::Int
+
+    # trainable layers and weights
+    map_nodes::Dense
+    map_points::Dense
+    map_pairs::Dense
+    map_final::Dense
+    header_weights_raw::Any  # raw per-head weights; γ = softplus(header_weights_raw)
+end
+
+Flux.@functor InvariantPointAttention
+
+"""
+    InvariantPointAttention(
+        n_dims_s, n_dims_z;
+        n_heads = 12,
+        c = 16,
+        n_query_points = 4,
+        n_point_values = 8)
+
+Create an invariant point attention layer.
+"""
+function InvariantPointAttention(
+    n_dims_s::Integer,
+    n_dims_z::Integer;
+    n_heads::Integer=12,
+    c::Integer=16,
+    n_query_points::Integer=4,
+    n_point_values::Integer=8
+)
+    # initialize layer weights so that outputs have std ≈ 1 (as assumed in
+    # AlphaFold2) when inputs follow the standard normal distribution
+    init = Flux.kaiming_uniform(gain=1.0)
+    map_nodes = Dense(n_dims_s => n_heads * c * 3, bias=false; init)
+    map_points = Dense(
+        n_dims_s => n_heads * (n_query_points * 2 + n_point_values) * 3,
+        bias=false;
+        init
+    )
+    map_pairs = Dense(n_dims_z => n_heads, bias=false; init)
+    map_final =
+        Dense(n_heads * (n_dims_z + c + n_point_values * (3 + 1)) => n_dims_s, bias=true)
+    # initialized so that the initial head weights γ = softplus(raw) are ones
+    header_weights_raw = @. log(expm1($(ones(Float32, n_heads))))
+    return InvariantPointAttention(
+        n_heads,
+        c,
+        n_query_points,
+        n_point_values,
+        map_nodes,
+        map_points,
+        map_pairs,
+        map_final,
+        header_weights_raw,
+    )
+end
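+
+# Shape conventions in the forward pass below: nodes_q/k/v are (n_heads, c, N);
+# points_q/k/v are (3, n_heads, n_points, N) and are expressed in the global
+# frame, so inter-point distances (and hence the attention logits) do not
+# change under a global rigid motion of the input frames.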
+
+# Algorithm 22
+function (ipa::InvariantPointAttention)(
+    g::GNNGraph,
+    s::AbstractMatrix,
+    z::AbstractMatrix,
+    rigid::RigidTransformation,
+)
+    F = eltype(s)
+    n_residues = size(s, 2)
+    (; n_heads, c, n_query_points, n_point_values) = ipa
+
+    # map inputs (residues come at the last dimension)
+    nodes = reshape(ipa.map_nodes(s), n_heads, :, n_residues)
+    points = transform(rigid, reshape(ipa.map_points(s), 3, :, n_residues))
+    bias = ipa.map_pairs(z)
+
+    # split into queries, keys and values
+    nodes_q, nodes_k, nodes_v = chunk(nodes, size=[c, c, c], dims=2)
+    points_q, points_k, points_v = chunk(
+        points,
+        size=n_heads * [n_query_points, n_query_points, n_point_values],
+        dims=2,
+    )
+    points_q = reshape(points_q, 3, n_heads, :, n_residues)
+    points_k = reshape(points_k, 3, n_heads, :, n_residues)
+    points_v = reshape(points_v, 3, n_heads, :, n_residues)
+
+    # run message passing
+    w_C = F(√(2 / 9n_query_points))  # point-term weight from Algorithm 22
+    w_L = F(1 / √3)                  # each of the three logit terms gets equal weight
+    γ = softplus.(ipa.header_weights_raw)
+    function message(xi, xj, e)
+        u = sumdrop(xi.nodes_q .* xj.nodes_k, dims=2)  # inner products
+        v = sumdrop(abs2.(xi.points_q .- xj.points_k), dims=(1, 3))  # sum of squared distances
+        attn_logits = @. w_L * (1 / √$(F(c)) * u + e - γ * w_C / 2 * v)  # logits of attention scores
+        return (; attn_logits, nodes_v=xj.nodes_v, points_v=xj.points_v)
+    end
+    xi = xj = (; nodes_q, nodes_k, nodes_v, points_q, points_k, points_v)
+    e = bias
+    msgs = apply_edges(message, g; xi, xj, e)
+
+    # aggregate messages from neighbors
+    attn = softmax_edge_neighbors(g, msgs.attn_logits)  # (heads, edges)
+    out_pairs =
+        aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* unsqueeze(z, dims=1))
+    out_nodes = aggregate_neighbors(g, +, reshape(attn, n_heads, 1, :) .* msgs.nodes_v)
+    out_points = aggregate_neighbors(g, +, reshape(attn, 1, n_heads, 1, :) .* msgs.points_v)
+    out_points = inverse_transform(rigid, reshape(out_points, 3, :, n_residues))
+    out_points_norm = sqrt.(sumdrop(abs2.(out_points), dims=1))
+
+    # return the final output
+    out = vcat(flatten.((out_pairs, out_nodes, out_points, out_points_norm))...)
+    return ipa.map_final(out)
+end
+
+# sum over `dims` and drop the summed dimensions
+sumdrop(x; dims) = dropdims(sum(x; dims); dims)
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 4806a26..cd391c1 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,46 @@
 using MessagePassingIPA
+using GraphNeuralNetworks: rand_graph
+using Rotations: RotMatrix
 using Test
 
 @testset "MessagePassingIPA.jl" begin
-    # Write your tests here.
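+    # transform/inverse_transform must be mutually inverse, and the IPA output
+    # must be invariant under global rigid motions of the inputs.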
+    @testset "RigidTransformation" begin
+        n = 100
+        rotations = stack(rand(RotMatrix{3,Float32}) for _ in 1:n)
+        translations = randn(Float32, 3, n)
+        rigid = MessagePassingIPA.RigidTransformation(rotations, translations)
+        x = randn(Float32, 3, 12, n)
+        y = MessagePassingIPA.transform(rigid, x)
+        @test size(x) == size(y)
+        @test x ≈ MessagePassingIPA.inverse_transform(rigid, y)
+    end
+
+    @testset "InvariantPointAttention" begin
+        n_dims_s = 32
+        n_dims_z = 16
+        ipa = MessagePassingIPA.InvariantPointAttention(n_dims_s, n_dims_z)
+
+        n_nodes = 100
+        n_edges = 500
+        g = rand_graph(n_nodes, n_edges)
+        s = randn(Float32, n_dims_s, n_nodes)
+        z = randn(Float32, n_dims_z, n_edges)
+
+        # check returned type and size
+        c = randn(Float32, 3, n_nodes) * 1000
+        x1 = c .+ randn(Float32, 3, n_nodes)
+        x2 = c .+ randn(Float32, 3, n_nodes)
+        x3 = c .+ randn(Float32, 3, n_nodes)
+        rigid1 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
+        @test ipa(g, s, z, rigid1) isa Matrix{Float32}
+        @test size(ipa(g, s, z, rigid1)) == (n_dims_s, n_nodes)
+
+        # check invariance
+        R, t = rand(RotMatrix{3,Float32}), randn(Float32, 3)
+        x1 = R * x1 .+ t
+        x2 = R * x2 .+ t
+        x3 = R * x3 .+ t
+        rigid2 = MessagePassingIPA.RigidTransformation(MessagePassingIPA.rigid_from_3points(x1, x2, x3)...)
+        @test ipa(g, s, z, rigid1) ≈ ipa(g, s, z, rigid2)
+    end
 end