From 9338ed7bfcd41d6c58973329dd2510f85a9cf656 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 28 Jul 2024 13:17:28 +0200 Subject: [PATCH] use GNNlib in GNN.jl (#464) * use GNNlib in GNN.jl * cleanup * ported all graph convs * workflow * fix * fix gcn_con * fix gcn_con * add comments --- .../workflows/test_GraphNeuralNetworks.yml | 2 +- GNNlib/src/GNNlib.jl | 107 ++- GNNlib/src/layers/conv.jl | 134 +++- Project.toml | 10 +- ext/GraphNeuralNetworksCUDAExt.jl | 49 -- src/GraphNeuralNetworks.jl | 114 ++-- src/deprecations.jl | 4 +- src/layers/basic.jl | 5 +- src/layers/conv.jl | 640 +----------------- src/msgpass.jl | 284 -------- src/utils.jl | 126 ---- test/layers/conv.jl | 2 +- 12 files changed, 265 insertions(+), 1212 deletions(-) delete mode 100644 ext/GraphNeuralNetworksCUDAExt.jl delete mode 100644 src/msgpass.jl delete mode 100644 src/utils.jl diff --git a/.github/workflows/test_GraphNeuralNetworks.yml b/.github/workflows/test_GraphNeuralNetworks.yml index a9c04a93d..49e5e6074 100644 --- a/.github/workflows/test_GraphNeuralNetworks.yml +++ b/.github/workflows/test_GraphNeuralNetworks.yml @@ -37,7 +37,7 @@ jobs: # dev mono repo versions pkg"registry up" Pkg.update() - pkg"dev ./GNNGraphs ." + pkg"dev ./GNNGraphs ./GNNlib ." Pkg.test("GraphNeuralNetworks"; coverage=true) - uses: julia-actions/julia-processcoverage@v1 with: diff --git a/GNNlib/src/GNNlib.jl b/GNNlib/src/GNNlib.jl index 8acb1e4c4..a84253776 100644 --- a/GNNlib/src/GNNlib.jl +++ b/GNNlib/src/GNNlib.jl @@ -12,70 +12,65 @@ using .GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, EType, NType # for heteroconvs -export - # utils - reduce_nodes, - reduce_edges, - softmax_nodes, - softmax_edges, - broadcast_nodes, - broadcast_edges, - softmax_edge_neighbors, - # msgpass - apply_edges, - aggregate_neighbors, - propagate, - copy_xj, - copy_xi, - xi_dot_xj, - xi_sub_xj, - xj_sub_xi, - e_mul_xj, - w_mul_xj +include("utils.jl") +export reduce_nodes, + reduce_edges, + softmax_nodes, + softmax_edges, + broadcast_nodes, + broadcast_edges, + softmax_edge_neighbors +include("msgpass.jl") +export apply_edges, + aggregate_neighbors, + propagate, + copy_xj, + copy_xi, + xi_dot_xj, + xi_sub_xj, + xj_sub_xi, + e_mul_xj, + w_mul_xj + ## The following methods are defined but not exported -# # layers/basic -# dot_decoder, - -# # layers/conv -# agnn_conv, -# cg_conv, -# cheb_conv, -# edge_conv, -# egnn_conv, -# gat_conv, -# gatv2_conv, -# gated_graph_conv, -# gcn_conv, -# gin_conv, -# gmm_conv, -# graph_conv, -# megnet_conv, -# nn_conv, -# res_gated_graph_conv, -# sage_conv, -# sg_conv, -# transformer_conv, +include("layers/basic.jl") +export dot_decoder -# # layers/temporalconv -# a3tgcn_conv, +include("layers/conv.jl") +export agnn_conv, + cg_conv, + cheb_conv, + d_conv, + edge_conv, + egnn_conv, + gat_conv, + gatv2_conv, + gated_graph_conv, + gcn_conv, + gin_conv, + gmm_conv, + graph_conv, + megnet_conv, + nn_conv, + res_gated_graph_conv, + sage_conv, + sg_conv, + tag_conv, + transformer_conv -# # layers/pool -# global_pool, -# global_attention_pool, -# set2set_pool, -# topk_pool, -# topk_index, +include("layers/temporalconv.jl") +export a3tgcn_conv +include("layers/pool.jl") +export global_pool, + global_attention_pool, + set2set_pool, + topk_pool, + topk_index -include("utils.jl") -include("layers/basic.jl") -include("layers/conv.jl") # include("layers/heteroconv.jl") # no functional part at the moment -include("layers/temporalconv.jl") -include("layers/pool.jl") -include("msgpass.jl") end #module \ No newline at end of file diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index ba9274609..cd3606291 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -1,3 +1,4 @@ +####################### GCNConv ###################################### check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs")) @@ -71,11 +72,14 @@ function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_w end # when we also have edge_weight we need to convert the graph to COO -function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector, norm_fn::F) where F +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x, edge_weight::EW, norm_fn::F, conv_weight::CW) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return gcn_conv(l, g, x, edge_weight, norm_fn) + return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) end +####################### ChebConv ###################################### + function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T} check_num_nodes(g, X) @assert size(X, 1) == size(l.weight, 2) "Input feature size must match input channel size." @@ -93,6 +97,8 @@ function cheb_conv(l, g::GNNGraph, X::AbstractMatrix{T}) where {T} return Y .+ l.bias end +####################### GraphConv ###################################### + function graph_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -101,6 +107,8 @@ function graph_conv(l, g::AbstractGNNGraph, x) return l.σ.(x .+ l.bias) end +####################### GATConv ###################################### + function gat_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" @@ -157,6 +165,8 @@ function gat_message(l, Wxi, Wxj, e) return (; logα, Wxj) end +####################### GATv2Conv ###################################### + function gatv2_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" @@ -201,6 +211,7 @@ function gatv2_message(l, Wxi, Wxj, e) return (; logα, Wxj) end +####################### GatedGraphConv ###################################### # TODO PIRACY! remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521 @non_differentiable fill!(x...) @@ -221,6 +232,8 @@ function gated_graph_conv(l, g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real return H end +####################### EdgeConv ###################################### + function edge_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -232,6 +245,7 @@ end edge_conv_message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi)) +####################### GINConv ###################################### function gin_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) @@ -242,6 +256,8 @@ function gin_conv(l, g::AbstractGNNGraph, x) return l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) end +####################### NNConv ###################################### + function nn_conv(l, g::GNNGraph, x::AbstractMatrix, e) check_num_nodes(g, x) message = Fix1(nn_conv_message, l) @@ -257,6 +273,8 @@ function nn_conv_message(l, xi, xj, e) return reshape(m, :, nedges) end +####################### SAGEConv ###################################### + function sage_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -265,6 +283,8 @@ function sage_conv(l, g::AbstractGNNGraph, x) return x end +####################### ResGatedConv ###################################### + function res_gated_graph_conv(l, g::AbstractGNNGraph, x) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -280,6 +300,8 @@ function res_gated_graph_conv(l, g::AbstractGNNGraph, x) return l.σ.(l.U * xi .+ m .+ l.bias) end +####################### CGConv ###################################### + function cg_conv(l, g::AbstractGNNGraph, x, e::Union{Nothing, AbstractMatrix} = nothing) check_num_nodes(g, x) xj, xi = expand_srcdst(g, x) @@ -311,6 +333,7 @@ function cg_message(l, xi, xj, e) return l.dense_f(z) .* l.dense_s(z) end +####################### AGNNConv ###################################### function agnn_conv(l, g::GNNGraph, x::AbstractMatrix) check_num_nodes(g, x) @@ -329,6 +352,8 @@ function agnn_conv(l, g::GNNGraph, x::AbstractMatrix) return x end +####################### MegNetConv ###################################### + function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) check_num_nodes(g, x) @@ -343,6 +368,8 @@ function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) return x̄, ē end +####################### GMMConv ###################################### + function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) (nin, ein), out = l.ch #Notational Simplicity @@ -374,6 +401,8 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) return m end +####################### SGCConv ###################################### + # this layer is not stable enough to be supported by GNNHeteroGraph type # due to it's looping mechanism function sgc_conv(l, g::GNNGraph, x::AbstractMatrix{T}, @@ -425,6 +454,8 @@ function sgc_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, return sgc_conv(l, g, x, edge_weight) end +####################### EGNNGConv ###################################### + function egnn_conv(l, g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) if l.num_features.edge > 0 @assert e!==nothing "Edge features must be provided." @@ -463,6 +494,8 @@ function egnn_message(l, xi, xj, e) return (; x = msg_x, h = msg_h) end +######################## SGConv ###################################### + # this layer is not stable enough to be supported by GNNHeteroGraph type # due to it's looping mechanism function sg_conv(l, g::GNNGraph, x::AbstractMatrix{T}, @@ -514,6 +547,8 @@ function sg_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, return sg_conv(l, g, x, edge_weight) end +######################## TransformerConv ###################################### + function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing} = nothing) check_num_nodes(g, x) @@ -591,3 +626,98 @@ function transformer_message_main(xi, xj, e) end return e.α .* val end + + +######################## TAGConv ###################################### + +function tag_conv(l, g::GNNGraph, x::AbstractMatrix{T}, + edge_weight::EW = nothing) where + {T, EW <: Union{Nothing, AbstractVector}} + @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" + + if edge_weight !== nothing + @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) + end + c = 1 ./ sqrt.(d) + + sum_pow = 0 + sum_total = 0 + for iter in 1:(l.k) + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = x) + else + x = propagate(copy_xj, g, +, xj = x) + end + x = x .* c' + + # On the first iteration, initialize sum_pow with the first propagated features + # On subsequent iterations, accumulate propagated features + if iter == 1 + sum_pow = x + sum_total = l.weight * sum_pow + else + sum_pow += x + # Weighted sum of features for each power of adjacency matrix + # This applies the weight matrix to the accumulated sum of propagated features + sum_total += l.weight * sum_pow + end + end + + return (sum_total .+ l.bias) +end + +function tag_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector) + g = GNNGraph(edge_index(g)...; g.num_nodes) + return l(g, x, edge_weight) +end + +######################## DConv ###################################### + +function d_conv(l, g::GNNGraph, x::AbstractMatrix) + #A = adjacency_matrix(g, weighted = true) + s, t = edge_index(g) + gt = GNNGraph(t, s, get_edge_weight(g)) + deg_out = degree(g; dir = :out) + deg_in = degree(g; dir = :in) + deg_out = Diagonal(deg_out) + deg_in = Diagonal(deg_in) + + h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x + + T0 = x + if l.K > 1 + # T1_in = T0 * deg_in * A' + #T1_out = T0 * deg_out' * A + T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out') + T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in) + h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out + end + for i in 2:l.K + T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in) + T2_in = 2 * T2_in - T0 + T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out') + T2_out = 2 * T2_out - T0 + h = h .+ l.weights[1,i,:,:] * T2_in .+ l.weights[2,i,:,:] * T2_out + T1_in = T2_in + T1_out = T2_out + end + return h .+ l.bias +end diff --git a/Project.toml b/Project.toml index 51b2e2f15..7a429a998 100644 --- a/Project.toml +++ b/Project.toml @@ -9,9 +9,10 @@ DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" +GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -20,8 +21,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -[extensions] -GraphNeuralNetworksCUDAExt = "CUDA" +# [extensions] +# GraphNeuralNetworksCUDAExt = "CUDA" [compat] CUDA = "4, 5" @@ -30,9 +31,10 @@ DataStructures = "0.18" Flux = "0.14" Functors = "0.4.1" GNNGraphs = "1.0" +GNNlib = "0.2" LinearAlgebra = "1" -MacroTools = "0.5" MLUtils = "0.4" +MacroTools = "0.5" NNlib = "0.9" Random = "1" Reexport = "1" diff --git a/ext/GraphNeuralNetworksCUDAExt.jl b/ext/GraphNeuralNetworksCUDAExt.jl deleted file mode 100644 index cf11f4dec..000000000 --- a/ext/GraphNeuralNetworksCUDAExt.jl +++ /dev/null @@ -1,49 +0,0 @@ -module GraphNeuralNetworksCUDAExt - -using CUDA -using Random, Statistics, LinearAlgebra -using GraphNeuralNetworks -using GNNGraphs -using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T -import GraphNeuralNetworks: propagate - -const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix} - -###### PROPAGATE SPECIALIZATIONS #################### - -## COPY_XJ - -## avoid the fast path on gpu until we have better cuda support -function propagate(::typeof(copy_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e) - propagate((xi, xj, e) -> copy_xj(xi, xj, e), g, +, xi, xj, e) -end - -## E_MUL_XJ - -## avoid the fast path on gpu until we have better cuda support -function propagate(::typeof(e_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e::AbstractVector) - propagate((xi, xj, e) -> e_mul_xj(xi, xj, e), g, +, xi, xj, e) -end - -## W_MUL_XJ - -## avoid the fast path on gpu until we have better cuda support -function propagate(::typeof(w_mul_xj), g::GNNGraph{<:Union{COO_T, SPARSE_T}}, ::typeof(+), - xi, xj::AnyCuMatrix, e::Nothing) - propagate((xi, xj, e) -> w_mul_xj(xi, xj, e), g, +, xi, xj, e) -end - -# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) -# A = adjacency_matrix(g, weighted=false) -# D = compute_degree(A) -# return xj * A * D -# end - -# # Zygote bug. Error with sparse matrix without nograd -# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) - -# Flux.Zygote.@nograd compute_degree - -end #module diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 0471e6555..0f5c43a18 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -12,88 +12,58 @@ using Reexport using DataStructures: nlargest using MLUtils: zeros_like -@reexport using GNNGraphs using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, EType, NType # for heteroconvs -export -# utils - reduce_nodes, - reduce_edges, - softmax_nodes, - softmax_edges, - broadcast_nodes, - broadcast_edges, - softmax_edge_neighbors, - -# msgpass - apply_edges, - aggregate_neighbors, - propagate, - copy_xj, - copy_xi, - xi_dot_xj, - xi_sub_xj, - xj_sub_xi, - e_mul_xj, - w_mul_xj, - -# layers/basic - GNNLayer, - GNNChain, - WithGraph, - DotDecoder, - -# layers/conv - AGNNConv, - CGConv, - ChebConv, - EdgeConv, - EGNNConv, - GATConv, - GATv2Conv, - GatedGraphConv, - GCNConv, - GINConv, - GMMConv, - GraphConv, - MEGNetConv, - NNConv, - ResGatedGraphConv, - SAGEConv, - SGConv, - TAGConv, - TransformerConv, - DConv, - -# layers/heteroconv - HeteroGraphConv, - -# layers/temporalconv - TGCN, - A3TGCN, - GConvLSTM, - GConvGRU, - DCGRU, - -# layers/pool - GlobalPool, - GlobalAttentionPool, - Set2Set, - TopKPool, - topk_index, - -# mldatasets - mldataset2gnngraph +@reexport using GNNGraphs +@reexport using GNNlib -include("utils.jl") include("layers/basic.jl") +export GNNLayer, + GNNChain, + WithGraph, + DotDecoder + include("layers/conv.jl") +export AGNNConv, + CGConv, + ChebConv, + DConv, + EdgeConv, + EGNNConv, + GATConv, + GATv2Conv, + GatedGraphConv, + GCNConv, + GINConv, + GMMConv, + GraphConv, + MEGNetConv, + NNConv, + ResGatedGraphConv, + SAGEConv, + SGConv, + TAGConv, + TransformerConv + include("layers/heteroconv.jl") +export HeteroGraphConv + include("layers/temporalconv.jl") +export TGCN, + A3TGCN, + GConvLSTM, + GConvGRU, + DCGRU + include("layers/pool.jl") -include("msgpass.jl") +export GlobalPool, + GlobalAttentionPool, + Set2Set, + TopKPool, + topk_index + include("deprecations.jl") end diff --git a/src/deprecations.jl b/src/deprecations.jl index 5416e9581..28f6532cc 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,2 +1,4 @@ -@deprecate AGNNConv(init_beta) AGNNConv(; init_beta) +# V1.0 deprecations +# TODO doe some reason this is not working +# @deprecate (l::GCNConv)(g, x, edge_weight, norm_fn; conv_weight=nothing) l(g, x, edge_weight; norm_fn, conv_weight) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index de200a802..c9322f0ec 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -215,7 +215,4 @@ julia> dotdec(g, rand(2, 5)) """ struct DotDecoder <: GNNLayer end -function (::DotDecoder)(g, x) - check_num_nodes(g, x) - return apply_edges(xi_dot_xj, g, xi = x, xj = x) -end +(::DotDecoder)(g, x) = GNNlib.dot_decoder(g, x) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index ed6aca3e0..ddfa4e945 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -32,7 +32,7 @@ and optionally an edge weight vector. # Forward - (::GCNConv)(g::GNNGraph, x::AbstractMatrix, edge_weight = nothing, norm_fn::Function = d -> 1 ./ sqrt.(d), conv_weight::Union{Nothing,AbstractMatrix} = nothing) -> AbstractMatrix + (::GCNConv)(g::GNNGraph, x, edge_weight = nothing; norm_fn = d -> 1 ./ sqrt.(d), conv_weight = nothing) -> AbstractMatrix Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, and optionally an edge weight vector. Returns a node feature matrix of size @@ -60,7 +60,7 @@ y = l(g, x) # size: 5 × num_nodes # convolution with edge weights and custom normalization function w = [1.1, 0.1, 2.3, 0.5] custom_norm_fn(d) = 1 ./ sqrt.(d + 1) # Custom normalization function -y = l(g, x, w, custom_norm_fn) +y = l(g, x, w; norm_fn = custom_norm_fn) # Edge weights can also be embedded in the graph. g = GNNGraph(s, t, w) @@ -89,88 +89,14 @@ function GCNConv(ch::Pair{Int, Int}, σ = identity; GCNConv(W, b, σ, add_self_loops, use_edge_weight) end -check_gcnconv_input(g::AbstractGNNGraph{<:ADJMAT_T}, edge_weight::AbstractVector) = - throw(ArgumentError("Providing external edge_weight is not yet supported for adjacency matrix graphs")) -function check_gcnconv_input(g::AbstractGNNGraph, edge_weight::AbstractVector) - if length(edge_weight) !== g.num_edges - throw(ArgumentError("Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))")) - end -end - -check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing - -function (l::GCNConv)(g::AbstractGNNGraph, - x, - edge_weight::EW = nothing, - norm_fn::Function = d -> 1 ./ sqrt.(d); - conv_weight::Union{Nothing,AbstractMatrix} = nothing - ) where {EW <: Union{Nothing, AbstractVector}} - - check_gcnconv_input(g, edge_weight) - - if conv_weight === nothing - weight = l.weight - else - weight = conv_weight - if size(weight) != size(l.weight) - throw(ArgumentError("The weight matrix has the wrong size. Expected $(size(l.weight)) but got $(size(weight))")) - end - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - # Pad weights with ones - # TODO for ADJMAT_T the new edges are not generally at the end - edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(weight) - if Dout < Din && !(g isa GNNHeteroGraph) - # multiply before convolution if it is more convenient, otherwise multiply after - # (this works only for homogenous graph) - x = weight * x - end +function (l::GCNConv)(g, x, edge_weight = nothing; + norm_fn = d -> 1 ./ sqrt.(d), + conv_weight = nothing) - xj, xi = expand_srcdst(g, x) # expand only after potential multiplication - T = eltype(xi) - - if g isa GNNHeteroGraph - din = degree(g, g.etypes[1], T; dir = :in) - dout = degree(g, g.etypes[1], T; dir = :out) - - cout = norm_fn(dout) - cin = norm_fn(din) - else - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight) - end - cin = cout = norm_fn(d) - end - xj = xj .* cout' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = xj, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = xj) - else - x = propagate(copy_xj, g, +, xj = xj) - end - x = x .* cin' - if Dout >= Din || g isa GNNHeteroGraph - x = weight * x - end - return l.σ.(x .+ l.bias) + return GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight) end -function (l::GCNConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector, norm_fn::Function) - g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return l(g, x, edge_weight, norm_fn) -end function Base.show(io::IO, l::GCNConv) out, in = size(l.weight) @@ -243,22 +169,7 @@ end @functor ChebConv -function (c::ChebConv)(g::GNNGraph, X::AbstractMatrix{T}) where {T} - check_num_nodes(g, X) - @assert size(X, 1)==size(c.weight, 2) "Input feature size must match input channel size." - - L̃ = scaled_laplacian(g, eltype(X)) - - Z_prev = X - Z = X * L̃ - Y = view(c.weight, :, :, 1) * Z_prev - Y += view(c.weight, :, :, 2) * Z - for k in 3:(c.k) - Z, Z_prev = 2 * Z * L̃ - Z_prev, Z - Y += view(c.weight, :, :, k) * Z - end - return Y .+ c.bias -end +(l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x) function Base.show(io::IO, l::ChebConv) out, in, k = size(l.weight) @@ -325,13 +236,7 @@ function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, GraphConv(W1, W2, b, σ, aggr) end -function (l::GraphConv)(g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight1 * xi .+ l.weight2 * m .+ l.bias) - return x -end +(l::GraphConv)(g, x) = GNNlib.graph_conv(l, g, x) function Base.show(io::IO, l::GraphConv) in_channel = size(l.weight1, ndims(l.weight1)) @@ -433,61 +338,7 @@ end (l::GATConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) -function (l::GATConv)(g::AbstractGNNGraph, x, - e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" - @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" - - xj, xi = expand_srcdst(g, x) - - if l.add_self_loops - @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." - g = add_self_loops(g) - end - - _, chout = l.channel - heads = l.heads - - Wxi = Wxj = l.dense_x(xj) - Wxi = Wxj = reshape(Wxj, chout, heads, :) - - if xi !== xj - Wxi = l.dense_x(xi) - Wxi = reshape(Wxi, chout, heads, :) - end - - # a hand-written message passing - m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e) - α = softmax_edge_neighbors(g, m.logα) - α = dropout(α, l.dropout) - β = α .* m.Wxj - x = aggregate_neighbors(g, +, β) - - if !l.concat - x = mean(x, dims = 2) - end - x = reshape(x, :, size(x, 3)) # return a matrix - x = l.σ.(x .+ l.bias) - - return x -end - -function message(l::GATConv, Wxi, Wxj, e) - _, chout = l.channel - heads = l.heads - - if e === nothing - Wxx = vcat(Wxi, Wxj) - else - We = l.dense_e(e) - We = reshape(We, chout, heads, :) # chout × nheads × nnodes - Wxx = vcat(Wxi, Wxj, We) - end - aWW = sum(l.a .* Wxx, dims = 1) # 1 × nheads × nedges - logα = leakyrelu.(aWW, l.negative_slope) - return (; logα, Wxj) -end +(l::GATConv)(g, x, e = nothing) = GNNlib.gat_conv(l, g, x, e) function Base.show(io::IO, l::GATConv) (in, ein), out = l.channel @@ -607,49 +458,7 @@ end (l::GATv2Conv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) -function (l::GATv2Conv)(g::AbstractGNNGraph, x, - e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - @assert !((e === nothing) && (l.dense_e !== nothing)) "Input edge features required for this layer" - @assert !((e !== nothing) && (l.dense_e === nothing)) "Input edge features were not specified in the layer constructor" - - xj, xi = expand_srcdst(g, x) - - if l.add_self_loops - @assert e===nothing "Using edge features and setting add_self_loops=true at the same time is not yet supported." - g = add_self_loops(g) - end - _, out = l.channel - heads = l.heads - - Wxi = reshape(l.dense_i(xi), out, heads, :) # out × heads × nnodes - Wxj = reshape(l.dense_j(xj), out, heads, :) # out × heads × nnodes - - m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e) - α = softmax_edge_neighbors(g, m.logα) - α = dropout(α, l.dropout) - β = α .* m.Wxj - x = aggregate_neighbors(g, +, β) - - if !l.concat - x = mean(x, dims = 2) - end - x = reshape(x, :, size(x, 3)) - x = l.σ.(x .+ l.bias) - return x -end - -function message(l::GATv2Conv, Wxi, Wxj, e) - _, out = l.channel - heads = l.heads - - Wx = Wxi + Wxj # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?" - if e !== nothing - Wx += reshape(l.dense_e(e), out, heads, :) - end - logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims = 1) # 1 × heads × nedges - return (; logα, Wxj) -end +(l::GATv2Conv)(g, x, e=nothing) = GNNlib.gatv2_conv(l, g, x, e) function Base.show(io::IO, l::GATv2Conv) (in, ein), out = l.channel @@ -715,24 +524,8 @@ function GatedGraphConv(out_ch::Int, num_layers::Int; GatedGraphConv(w, gru, out_ch, num_layers, aggr) end -# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521 -@non_differentiable fill!(x...) -function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {S <: Real} - check_num_nodes(g, H) - m, n = size(H) - @assert (m<=l.out_ch) "number of input features must less or equals to output features." - if m < l.out_ch - Hpad = similar(H, S, l.out_ch - m, n) - H = vcat(H, fill!(Hpad, 0)) - end - for i in 1:(l.num_layers) - M = view(l.weight, :, :, i) * H - M = propagate(copy_xj, g, l.aggr; xj = M) - H, _ = l.gru(H, M) - end - H -end +(l::GatedGraphConv)(g, H) = GNNlib.gated_graph_conv(l, g, H) function Base.show(io::IO, l::GatedGraphConv) print(io, "GatedGraphConv(($(l.out_ch) => $(l.out_ch))^$(l.num_layers)") @@ -783,15 +576,7 @@ end EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr) -function (l::EdgeConv)(g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - message(l, xi, xj, e) = l.nn(vcat(xi, xj .- xi)) - - x = propagate(message, g, l.aggr, l, xi = xi, xj = xj, e = nothing) - return x -end +(l::EdgeConv)(g, x) = GNNlib.edge_conv(l, g, x) function Base.show(io::IO, l::EdgeConv) print(io, "EdgeConv(", l.nn) @@ -846,14 +631,7 @@ Flux.trainable(l::GINConv) = (nn = l.nn,) GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) -function (l::GINConv)(g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - m = propagate(copy_xj, g, l.aggr, xj = xj) - - l.nn((1 .+ ofeltype(xi, l.ϵ)) .* xi .+ m) -end +(l::GINConv)(g, x) = GNNlib.gin_conv(l, g, x) function Base.show(io::IO, l::GINConv) print(io, "GINConv($(l.nn)") @@ -929,20 +707,7 @@ function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true, return NNConv(W, b, nn, σ, aggr) end -function (l::NNConv)(g::GNNGraph, x::AbstractMatrix, e) - check_num_nodes(g, x) - - m = propagate(message, g, l.aggr, l, xj = x, e = e) - return l.σ.(l.weight * x .+ m .+ l.bias) -end - -function message(l::NNConv, xi, xj, e) - nin, nedges = size(xj) - W = reshape(l.nn(e), (:, nin, nedges)) - xj = reshape(xj, (nin, 1, nedges)) # needed by batched_mul - m = NNlib.batched_mul(W, xj) - return reshape(m, :, nedges) -end +(l::NNConv)(g, x, e) = GNNlib.nn_conv(l, g, x, e) (l::NNConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) @@ -1008,13 +773,7 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, SAGEConv(W, b, σ, aggr) end -function (l::SAGEConv)(g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) - return x -end +(l::SAGEConv)(g, x) = GNNlib.sage_conv(l, g, x) function Base.show(io::IO, l::SAGEConv) out_channel, in_channel = size(l.weight) @@ -1087,20 +846,7 @@ function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; return ResGatedGraphConv(A, B, U, V, b, σ) end -function (l::ResGatedGraphConv)(g::AbstractGNNGraph, x) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - message(xi, xj, e) = sigmoid.(xi.Ax .+ xj.Bx) .* xj.Vx - - Ax = l.A * xi - Bx = l.B * xj - Vx = l.V * xj - - m = propagate(message, g, +, xi = (; Ax), xj = (; Bx, Vx)) - - return l.σ.(l.U * xi .+ m .+ l.bias) -end +(l::ResGatedGraphConv)(g, x) = GNNlib.res_gated_graph_conv(l, g, x) function Base.show(io::IO, l::ResGatedGraphConv) out_channel, in_channel = size(l.A) @@ -1173,37 +919,8 @@ function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false, return CGConv(ch, dense_f, dense_s, residual) end -function (l::CGConv)(g::AbstractGNNGraph, x, - e::Union{Nothing, AbstractMatrix} = nothing) - check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - - if e !== nothing - check_num_edges(g, e) - end - - m = propagate(message, g, +, l, xi = xi, xj = xj, e = e) - - if l.residual - if size(x, 1) == size(m, 1) - m += x - else - @warn "number of output features different from number of input features, residual not applied." - end - end - - return m -end - +(l::CGConv)(g, x, e = nothing) = GNNlib.cg_conv(l, g, x, e) -function message(l::CGConv, xi, xj, e) - if e !== nothing - z = vcat(xi, xj, e) - else - z = vcat(xi, xj) - end - return l.dense_f(z) .* l.dense_s(z) -end (l::CGConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) @@ -1271,22 +988,7 @@ function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true) AGNNConv([init_beta], add_self_loops, trainable) end -function (l::AGNNConv)(g::GNNGraph, x::AbstractMatrix) - check_num_nodes(g, x) - if l.add_self_loops - g = add_self_loops(g) - end - - xn = x ./ sqrt.(sum(x .^ 2, dims = 1)) - cos_dist = apply_edges(xi_dot_xj, g, xi = xn, xj = xn) - α = softmax_edge_neighbors(g, l.β .* cos_dist) - - x = propagate(g, +; xj = x, e = α) do xi, xj, α - α .* xj - end - - return x -end +(l::AGNNConv)(g, x) = GNNlib.agnn_conv(l, g, x) @doc raw""" MEGNetConv(ϕe, ϕv; aggr=mean) @@ -1337,27 +1039,15 @@ function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) ϕv = Chain(Dense(nin + nout, nout, relu), Dense(nout, nout)) - MEGNetConv(ϕe, ϕv; aggr) + return MEGNetConv(ϕe, ϕv; aggr) end function (l::MEGNetConv)(g::GNNGraph) x, e = l(g, node_features(g), edge_features(g)) - g = GNNGraph(g, ndata = x, edata = e) + return GNNGraph(g, ndata = x, edata = e) end -function (l::MEGNetConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) - check_num_nodes(g, x) - - ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e - l.ϕe(vcat(xi, xj, e)) - end - - xᵉ = aggregate_neighbors(g, l.aggr, ē) - - x̄ = l.ϕv(vcat(x, xᵉ)) - - return x̄, ē -end +(l::MEGNetConv)(g, x, e) = GNNlib.megnet_conv(l, g, x, e) @doc raw""" GMMConv((in, ein) => out, σ=identity; K=1, bias=true, init=glorot_uniform, residual=false) @@ -1434,36 +1124,7 @@ function GMMConv(ch::Pair{NTuple{2, Int}, Int}, GMMConv(mu, sigma_inv, b, σ, ch, K, dense_x, residual) end -function (l::GMMConv)(g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) - (nin, ein), out = l.ch #Notational Simplicity - - @assert (ein == size(e)[1]&&g.num_edges == size(e)[2]) "Pseudo-cordinate dimension is not equal to (ein,num_edge)" - - num_edges = g.num_edges - w = reshape(e, (ein, 1, num_edges)) - mu = reshape(l.mu, (ein, l.K, 1)) - - w = @. ((w - mu)^2) / 2 - w = w .* reshape(l.sigma_inv .^ 2, (ein, l.K, 1)) - w = exp.(sum(w, dims = 1)) # (1, K, num_edge) - - xj = reshape(l.dense_x(x), (out, l.K, :)) # (out, K, num_nodes) - - m = propagate(e_mul_xj, g, mean, xj = xj, e = w) - m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes) - - m = l.σ(m .+ l.bias) - - if l.residual - if size(x, 1) == size(m, 1) - m += x - else - @warn "Residual not applied : output feature is not equal to input_feature" - end - end - - return m -end +(l::GMMConv)(g::GNNGraph, x, e) = GNNlib.gmm_conv(l, g, x, e) (l::GMMConv)(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) @@ -1543,56 +1204,7 @@ function SGConv(ch::Pair{Int, Int}, k = 1; SGConv(W, b, k, add_self_loops, use_edge_weight) end -# this layer is not stable enough to be supported by GNNHeteroGraph type -# due to it's looping mechanism -function (l::SGConv)(g::GNNGraph, x::AbstractMatrix{T}, - edge_weight::EW = nothing) where - {T, EW <: Union{Nothing, AbstractVector}} - @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" - - if edge_weight !== nothing - @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if Dout < Din - x = l.weight * x - end - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) - end - c = 1 ./ sqrt.(d) - for iter in 1:(l.k) - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) - else - x = propagate(copy_xj, g, +, xj = x) - end - x = x .* c' - end - if Dout >= Din - x = l.weight * x - end - return (x .+ l.bias) -end - -function (l::SGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) - return l(g, x, edge_weight) -end +(l::SGConv)(g, x, edge_weight = nothing) = GNNlib.sg_conv(l, g, x, edge_weight) function Base.show(io::IO, l::SGConv) out, in = size(l.weight) @@ -1647,7 +1259,6 @@ struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer use_edge_weight::Bool end - @functor TAGConv function TAGConv(ch::Pair{Int, Int}, k = 3; @@ -1661,64 +1272,7 @@ function TAGConv(ch::Pair{Int, Int}, k = 3; TAGConv(W, b, k, add_self_loops, use_edge_weight) end -function (l::TAGConv)(g::GNNGraph, x::AbstractMatrix{T}, - edge_weight::EW = nothing) where - {T, EW <: Union{Nothing, AbstractVector}} - @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" - - if edge_weight !== nothing - @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" - end - - if l.add_self_loops - g = add_self_loops(g) - if edge_weight !== nothing - edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] - @assert length(edge_weight) == g.num_edges - end - end - Dout, Din = size(l.weight) - if edge_weight !== nothing - d = degree(g, T; dir = :in, edge_weight) - else - d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) - end - c = 1 ./ sqrt.(d) - - sum_pow = 0 - sum_total = 0 - for iter in 1:(l.k) - x = x .* c' - if edge_weight !== nothing - x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) - elseif l.use_edge_weight - x = propagate(w_mul_xj, g, +, xj = x) - else - x = propagate(copy_xj, g, +, xj = x) - end - x = x .* c' - - # On the first iteration, initialize sum_pow with the first propagated features - # On subsequent iterations, accumulate propagated features - if iter == 1 - sum_pow = x - sum_total = l.weight * sum_pow - else - sum_pow += x - # Weighted sum of features for each power of adjacency matrix - # This applies the weight matrix to the accumulated sum of propagated features - sum_total += l.weight * sum_pow - end - end - - return (sum_total .+ l.bias) -end - -function (l::TAGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector) - g = GNNGraph(edge_index(g)...; g.num_nodes) - return l(g, x, edge_weight) -end +(l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight) function Base.show(io::IO, l::TAGConv) out, in = size(l.weight) @@ -1819,42 +1373,7 @@ function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1] return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) end -function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) - if l.num_features.edge > 0 - @assert e!==nothing "Edge features must be provided." - end - @assert size(h, 1)==l.num_features.in "Input features must match layer input size." - - x_diff = apply_edges(xi_sub_xj, g, x, x) - sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) - x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) - - msg = apply_edges(message, g, l, - xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff)) - h_aggr = aggregate_neighbors(g, +, msg.h) - x_aggr = aggregate_neighbors(g, mean, msg.x) - - hnew = l.ϕh(vcat(h, h_aggr)) - if l.residual - h = h .+ hnew - else - h = hnew - end - x = x .+ x_aggr - return h, x -end - -function message(l::EGNNConv, xi, xj, e) - if l.num_features.edge > 0 - f = vcat(xi.h, xj.h, e.sqnorm_xdiff, e.e) - else - f = vcat(xi.h, xj.h, e.sqnorm_xdiff) - end - - msg_h = l.ϕe(f) - msg_x = l.ϕx(msg_h) .* e.x_diff - return (; x = msg_x, h = msg_h) -end +(l::EGNNConv)(g, h, x, e = nothing) = GNNlib.egnn_conv(l, g, h, x, e) function Base.show(io::IO, l::EGNNConv) ne = l.num_features.edge @@ -1961,7 +1480,7 @@ end @functor TransformerConv function Flux.trainable(l::TransformerConv) - (W1 = l.W1, W2 = l.W2, W3 = l.W3, W4 = l.W4, W5 = l.W5, W6 = l.W6, FF = l.FF, BN1 = l.BN1, BN2 = l.BN2) + (; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2) end function TransformerConv(ch::Pair{Int, Int}, args...; kws...) @@ -2005,86 +1524,12 @@ function TransformerConv(ch::Pair{NTuple{2, Int}, Int}; Float32(√out)) end -function (l::TransformerConv)(g::GNNGraph, x::AbstractMatrix, - e::Union{AbstractMatrix, Nothing} = nothing) - check_num_nodes(g, x) - - if l.add_self_loops - g = add_self_loops(g) - end - - out = l.channels[2] - heads = l.heads - W1x = !isnothing(l.W1) ? l.W1(x) : nothing - W2x = reshape(l.W2(x), out, heads, :) - W3x = reshape(l.W3(x), out, heads, :) - W4x = reshape(l.W4(x), out, heads, :) - W6e = !isnothing(l.W6) ? reshape(l.W6(e), out, heads, :) : nothing - - m = apply_edges(message_uij, g, l; xi = (; W3x), xj = (; W4x), e = (; W6e)) - α = softmax_edge_neighbors(g, m) - α_val = propagate(message_main, g, +, l; xi = (; W3x), xj = (; W2x), e = (; W6e, α)) - - h = α_val - if l.concat - h = reshape(h, out * heads, :) # concatenate heads - else - h = mean(h, dims = 2) # average heads - h = reshape(h, out, :) - end - - if !isnothing(W1x) # root_weight - if !isnothing(l.W5) # gating - β = l.W5(vcat(h, W1x, h .- W1x)) - h = β .* W1x + (1.0f0 .- β) .* h - else - h += W1x - end - end - - if l.skip_connection - @assert size(h, 1)==size(x, 1) "In-channels must correspond to out-channels * heads if skip_connection is used" - h += x - end - if !isnothing(l.BN1) - h = l.BN1(h) - end - - if !isnothing(l.FF) - h1 = h - h = l.FF(h) - if l.skip_connection - h += h1 - end - if !isnothing(l.BN2) - h = l.BN2(h) - end - end - - return h -end +(l::TransformerConv)(g, x, e = nothing) = GNNlib.transformer_conv(l, g, x, e) function (l::TransformerConv)(g::GNNGraph) GNNGraph(g, ndata = l(g, node_features(g), edge_features(g))) end -function message_uij(l::TransformerConv, xi, xj, e) - key = xj.W4x - if !isnothing(e.W6e) - key += e.W6e - end - uij = sum(xi.W3x .* key, dims = 1) ./ l.sqrt_out - return uij -end - -function message_main(l::TransformerConv, xi, xj, e) - val = xj.W2x - if !isnothing(e.W6e) - val += e.W6e - end - return e.α .* val -end - function Base.show(io::IO, l::TransformerConv) (in, ein), out = l.channels print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))") @@ -2132,36 +1577,7 @@ function DConv(ch::Pair{Int, Int}, K::Int; init = glorot_uniform, bias = true) DConv(in, out, weights, b, K) end -function (l::DConv)(g::GNNGraph, x::AbstractMatrix) - #A = adjacency_matrix(g, weighted = true) - s, t = edge_index(g) - gt = GNNGraph(t, s, get_edge_weight(g)) - deg_out = degree(g; dir = :out) - deg_in = degree(g; dir = :in) - deg_out = Diagonal(deg_out) - deg_in = Diagonal(deg_in) - - h = l.weights[1,1,:,:] * x .+ l.weights[2,1,:,:] * x - - T0 = x - if l.K > 1 - # T1_in = T0 * deg_in * A' - #T1_out = T0 * deg_out' * A - T1_out = propagate(w_mul_xj, g, +; xj = T0*deg_out') - T1_in = propagate(w_mul_xj, gt, +; xj = T0*deg_in) - h = h .+ l.weights[1,2,:,:] * T1_in .+ l.weights[2,2,:,:] * T1_out - end - for i in 2:l.K - T2_in = propagate(w_mul_xj, gt, +; xj = T1_in*deg_in) - T2_in = 2 * T2_in - T0 - T2_out = propagate(w_mul_xj, g ,+; xj = T1_out*deg_out') - T2_out = 2 * T2_out - T0 - h = h .+ l.weights[1,i,:,:] * T2_in .+ l.weights[2,i,:,:] * T2_out - T1_in = T2_in - T1_out = T2_out - end - return h .+ l.bias -end +(l::DConv)(g, x) = GNNlib.d_conv(l, g, x) function Base.show(io::IO, l::DConv) print(io, "DConv($(l.in) => $(l.out), K=$(l.K))") diff --git a/src/msgpass.jl b/src/msgpass.jl deleted file mode 100644 index 2118cdac6..000000000 --- a/src/msgpass.jl +++ /dev/null @@ -1,284 +0,0 @@ -""" - propagate(fmsg, g, aggr [layer]; [xi, xj, e]) - propagate(fmsg, g, aggr, [layer,] xi, xj, e=nothing) - -Performs message passing on graph `g`. Takes care of materializing the node features on each edge, -applying the message function `fmsg`, and returning an aggregated message ``\\bar{\\mathbf{m}}`` -(depending on the return value of `fmsg`, an array or a named tuple of -arrays with last dimension's size `g.num_nodes`). - -If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg` -as a first argument. - -It can be decomposed in two steps: - -```julia -m = apply_edges(fmsg, g, xi, xj, e) -m̄ = aggregate_neighbors(g, aggr, m) -``` - -GNN layers typically call `propagate` in their forward pass, -providing as input `f` a closure. - -# Arguments - -- `g`: A `GNNGraph`. -- `xi`: An array or a named tuple containing arrays whose last dimension's size - is `g.num_nodes`. It will be appropriately materialized on the - target node of each edge (see also [`edge_index`](@ref)). -- `xj`: As `xj`, but to be materialized on edges' sources. -- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `fmsg`: A generic function that will be passed over to [`apply_edges`](@ref). - Has to take as inputs the edge-materialized `xi`, `xj`, and `e` - (arrays or named tuples of arrays whose last dimension' size is the size of - a batch of edges). Its output has to be an array or a named tuple of arrays - with the same batch size. If also `layer` is passed to propagate, - the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` - instead of `fmsg(xi, xj, e)`. -- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument. -- `aggr`: Neighborhood aggregation operator. Use `+`, `mean`, `max`, or `min`. - - -# Examples - -```julia -using GraphNeuralNetworks, Flux - -struct GNNConv <: GNNLayer - W - b - σ -end - -Flux.@functor GNNConv - -function GNNConv(ch::Pair{Int,Int}, σ=identity) - in, out = ch - W = Flux.glorot_uniform(out, in) - b = zeros(Float32, out) - GNNConv(W, b, σ) -end - -function (l::GNNConv)(g::GNNGraph, x::AbstractMatrix) - message(xi, xj, e) = l.W * xj - m̄ = propagate(message, g, +, xj=x) - return l.σ.(m̄ .+ l.bias) -end - -l = GNNConv(10 => 20) -l(g, x) -``` - -See also [`apply_edges`](@ref) and [`aggregate_neighbors`](@ref). -""" -function propagate end - -function propagate(f, g::AbstractGNNGraph, aggr; xi = nothing, xj = nothing, e = nothing) - propagate(f, g, aggr, xi, xj, e) -end - -function propagate(f, g::AbstractGNNGraph, aggr, xi, xj, e = nothing) - m = apply_edges(f, g, xi, xj, e) - m̄ = aggregate_neighbors(g, aggr, m) - return m̄ -end - -## convenience methods for working around performance issues -# https://github.com/JuliaLang/julia/issues/15276 -## and zygote issues -# https://github.com/FluxML/Zygote.jl/issues/1317 -function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer; xi = nothing, xj = nothing, - e = nothing) - propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) -end -function propagate(f, g::AbstractGNNGraph, aggr, l::GNNLayer, xi, xj, e = nothing) - propagate((xi, xj, e) -> f(l, xi, xj, e), g, aggr, xi, xj, e) -end - -## APPLY EDGES - -""" - apply_edges(fmsg, g, [layer]; [xi, xj, e]) - apply_edges(fmsg, g, [layer,] xi, xj, e=nothing) - -Returns the message from node `j` to node `i` applying -the message function `fmsg` on the edges in graph `g`. -In the message-passing scheme, the incoming messages -from the neighborhood of `i` will later be aggregated -in order to update the features of node `i` (see [`aggregate_neighbors`](@ref)). - -The function `fmsg` operates on batches of edges, therefore -`xi`, `xj`, and `e` are tensors whose last dimension -is the batch size, or can be named tuples of -such tensors. - -If also a [`GNNLayer`](@ref) `layer` is provided, it will be passed to `fmsg` -as a first argument. - -# Arguments - -- `g`: An `AbstractGNNGraph`. -- `xi`: An array or a named tuple containing arrays whose last dimension's size - is `g.num_nodes`. It will be appropriately materialized on the - target node of each edge (see also [`edge_index`](@ref)). -- `xj`: As `xi`, but now to be materialized on each edge's source node. -- `e`: An array or a named tuple containing arrays whose last dimension's size is `g.num_edges`. -- `fmsg`: A function that takes as inputs the edge-materialized `xi`, `xj`, and `e`. - These are arrays (or named tuples of arrays) whose last dimension' size is the size of - a batch of edges. The output of `f` has to be an array (or a named tuple of arrays) - with the same batch size. If also `layer` is passed to propagate, - the signature of `fmsg` has to be `fmsg(layer, xi, xj, e)` - instead of `fmsg(xi, xj, e)`. -- `layer`: A [`GNNLayer`](@ref). If provided it will be passed to `fmsg` as a first argument. - -See also [`propagate`](@ref) and [`aggregate_neighbors`](@ref). -""" -function apply_edges end - -function apply_edges(f, g::AbstractGNNGraph; xi = nothing, xj = nothing, e = nothing) - apply_edges(f, g, xi, xj, e) -end - -function apply_edges(f, g::AbstractGNNGraph, xi, xj, e = nothing) - check_num_nodes(g, (xj, xi)) - check_num_edges(g, e) - s, t = edge_index(g) # for heterographs, errors if more than one edge type - xi = GNNGraphs._gather(xi, t) # size: (D, num_nodes) -> (D, num_edges) - xj = GNNGraphs._gather(xj, s) - m = f(xi, xj, e) - return m -end - - -## convenience methods for working around performance issues -# https://github.com/JuliaLang/julia/issues/15276 -## and zygote issues -# https://github.com/FluxML/Zygote.jl/issues/1317 -function apply_edges(f, g::AbstractGNNGraph, l::GNNLayer; xi = nothing, xj = nothing, e = nothing) - apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) -end - -function apply_edges(f, g::AbstractGNNGraph, l::GNNLayer, xi, xj, e = nothing) - apply_edges((xi, xj, e) -> f(l, xi, xj, e), g, xi, xj, e) -end - -## AGGREGATE NEIGHBORS -@doc raw""" - aggregate_neighbors(g, aggr, m) - -Given a graph `g`, edge features `m`, and an aggregation -operator `aggr` (e.g `+, min, max, mean`), returns the new node -features -```math -\mathbf{x}_i = \square_{j \in \mathcal{N}(i)} \mathbf{m}_{j\to i} -``` - -Neighborhood aggregation is the second step of [`propagate`](@ref), -where it comes after [`apply_edges`](@ref). -""" -function aggregate_neighbors(g::GNNGraph, aggr, m) - check_num_edges(g, m) - s, t = edge_index(g) - return GNNGraphs._scatter(aggr, m, t, g.num_nodes) -end - -function aggregate_neighbors(g::GNNHeteroGraph, aggr, m) - check_num_edges(g, m) - s, t = edge_index(g) - dest_node_t = only(g.etypes)[3] - return GNNGraphs._scatter(aggr, m, t, g.num_nodes[dest_node_t]) -end - -### MESSAGE FUNCTIONS ### -""" - copy_xj(xi, xj, e) = xj -""" -copy_xj(xi, xj, e) = xj - -""" - copy_xi(xi, xj, e) = xi -""" -copy_xi(xi, xj, e) = xi - -""" - xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims=1) -""" -xi_dot_xj(xi, xj, e) = sum(xi .* xj, dims = 1) - -""" - xi_sub_xj(xi, xj, e) = xi .- xj -""" -xi_sub_xj(xi, xj, e) = xi .- xj - -""" - xj_sub_xi(xi, xj, e) = xj .- xi -""" -xj_sub_xi(xi, xj, e) = xj .- xi - -""" - e_mul_xj(xi, xj, e) = reshape(e, (...)) .* xj - -Reshape `e` into broadcast compatible shape with `xj` -(by prepending singleton dimensions) then perform -broadcasted multiplication. -""" -function e_mul_xj(xi, xj::AbstractArray{Tj, Nj}, - e::AbstractArray{Te, Ne}) where {Tj, Te, Nj, Ne} - @assert Ne <= Nj - e = reshape(e, ntuple(_ -> 1, Nj - Ne)..., size(e)...) - return e .* xj -end - -""" - w_mul_xj(xi, xj, w) = reshape(w, (...)) .* xj - -Similar to [`e_mul_xj`](@ref) but specialized on scalar edge features (weights). -""" -w_mul_xj(xi, xj::AbstractArray, w::Nothing) = xj # same as copy_xj if no weights - -function w_mul_xj(xi, xj::AbstractArray{Tj, Nj}, w::AbstractVector) where {Tj, Nj} - w = reshape(w, ntuple(_ -> 1, Nj - 1)..., length(w)) - return w .* xj -end - -###### PROPAGATE SPECIALIZATIONS #################### -## See also the methods defined in the package extensions. - -## COPY_XJ - -function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, e) - A = adjacency_matrix(g, weighted = false) - return xj * A -end - -## E_MUL_XJ - -# for weighted convolution -function propagate(::typeof(e_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, - e::AbstractVector) - g = set_edge_weight(g, e) - A = adjacency_matrix(g, weighted = true) - return xj * A -end - - -## W_MUL_XJ - -# for weighted convolution -function propagate(::typeof(w_mul_xj), g::GNNGraph, ::typeof(+), xi, xj::AbstractMatrix, - e::Nothing) - A = adjacency_matrix(g, weighted = true) - return xj * A -end - - -# function propagate(::typeof(copy_xj), g::GNNGraph, ::typeof(mean), xi, xj::AbstractMatrix, e) -# A = adjacency_matrix(g, weighted=false) -# D = compute_degree(A) -# return xj * A * D -# end - -# # Zygote bug. Error with sparse matrix without nograd -# compute_degree(A) = Diagonal(1f0 ./ vec(sum(A; dims=2))) - -# Flux.Zygote.@nograd compute_degree diff --git a/src/utils.jl b/src/utils.jl deleted file mode 100644 index 8434c63c8..000000000 --- a/src/utils.jl +++ /dev/null @@ -1,126 +0,0 @@ -ofeltype(x, y) = convert(float(eltype(x)), y) - -""" - reduce_nodes(aggr, g, x) - -For a batched graph `g`, return the graph-wise aggregation of the node -features `x`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. -The returned array will have last dimension `g.num_graphs`. - -See also: [`reduce_edges`](@ref). -""" -function reduce_nodes(aggr, g::GNNGraph, x) - @assert size(x)[end] == g.num_nodes - indexes = graph_indicator(g) - return NNlib.scatter(aggr, x, indexes) -end - -""" - reduce_nodes(aggr, indicator::AbstractVector, x) - -Return the graph-wise aggregation of the node features `x` given the -graph indicator `indicator`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. - -See also [`graph_indicator`](@ref). -""" -function reduce_nodes(aggr, indicator::AbstractVector, x) - return NNlib.scatter(aggr, x, indicator) -end - -""" - reduce_edges(aggr, g, e) - -For a batched graph `g`, return the graph-wise aggregation of the edge -features `e`. The aggregation operator `aggr` can be `+`, `mean`, `max`, or `min`. -The returned array will have last dimension `g.num_graphs`. -""" -function reduce_edges(aggr, g::GNNGraph, e) - @assert size(e)[end] == g.num_edges - s, t = edge_index(g) - indexes = graph_indicator(g)[s] - return NNlib.scatter(aggr, e, indexes) -end - -""" - softmax_nodes(g, x) - -Graph-wise softmax of the node features `x`. -""" -function softmax_nodes(g::GNNGraph, x) - @assert size(x)[end] == g.num_nodes - gi = graph_indicator(g) - max_ = gather(scatter(max, x, gi), gi) - num = exp.(x .- max_) - den = reduce_nodes(+, g, num) - den = gather(den, gi) - return num ./ den -end - -""" - softmax_edges(g, e) - -Graph-wise softmax of the edge features `e`. -""" -function softmax_edges(g::GNNGraph, e) - @assert size(e)[end] == g.num_edges - gi = graph_indicator(g, edges = true) - max_ = gather(scatter(max, e, gi), gi) - num = exp.(e .- max_) - den = reduce_edges(+, g, num) - den = gather(den, gi) - return num ./ (den .+ eps(eltype(e))) -end - -@doc raw""" - softmax_edge_neighbors(g, e) - -Softmax over each node's neighborhood of the edge features `e`. - -```math -\mathbf{e}'_{j\to i} = \frac{e^{\mathbf{e}_{j\to i}}} - {\sum_{j'\in N(i)} e^{\mathbf{e}_{j'\to i}}}. -``` -""" -function softmax_edge_neighbors(g::AbstractGNNGraph, e) - if g isa GNNHeteroGraph - for (key, value) in g.num_edges - @assert size(e)[end] == value - end - else - @assert size(e)[end] == g.num_edges - end - s, t = edge_index(g) - max_ = gather(scatter(max, e, t), t) - num = exp.(e .- max_) - den = gather(scatter(+, num, t), t) - return num ./ den -end - -""" - broadcast_nodes(g, x) - -Graph-wise broadcast array `x` of size `(*, g.num_graphs)` -to size `(*, g.num_nodes)`. -""" -function broadcast_nodes(g::GNNGraph, x) - @assert size(x)[end] == g.num_graphs - gi = graph_indicator(g) - return gather(x, gi) -end - -""" - broadcast_edges(g, x) - -Graph-wise broadcast array `x` of size `(*, g.num_graphs)` -to size `(*, g.num_edges)`. -""" -function broadcast_edges(g::GNNGraph, x) - @assert size(x)[end] == g.num_graphs - gi = graph_indicator(g, edges = true) - return gather(x, gi) -end - - -expand_srcdst(g::AbstractGNNGraph, x) = throw(ArgumentError("Invalid input type, expected matrix or tuple of matrices.")) -expand_srcdst(g::AbstractGNNGraph, x::AbstractMatrix) = (x, x) -expand_srcdst(g::AbstractGNNGraph, x::Tuple{<:AbstractMatrix, <:AbstractMatrix}) = x diff --git a/test/layers/conv.jl b/test/layers/conv.jl index f6846010a..224b98697 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -54,7 +54,7 @@ test_graphs = [g1, g_single_vertex] y = l(g, x) @test y[1, 1] ≈ w[1] / √(d[1] * d[2]) + w[2] / √(d[1] * d[3]) @test y[1, 2] ≈ w[3] / √(d[2] * d[1]) + w[4] / √(d[2] * d[3]) - @test y ≈ l(g, x, w, custom_norm_fn) # checking without custom + @test y ≈ l(g, x, w; norm_fn = custom_norm_fn) # checking without custom # test gradient with respect to edge weights w = rand(T, 6)