diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index ae4440f22..3da1bcb85 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -65,6 +65,7 @@ export ResGatedGraphConv, SAGEConv, SGConv, + TAGConv, TransformerConv, # layers/heteroconv diff --git a/src/layers/conv.jl b/src/layers/conv.jl index a098b2595..ee3ae81d3 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1590,6 +1590,132 @@ function Base.show(io::IO, l::SGConv) print(io, ")") end +@doc raw""" + TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false) + +TAGConv layer from [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/pdf/1710.10370.pdf). +This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data. +It performs the operation: + +```math +H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k} +``` + +where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`. + +# Arguments +- `in`: Number of input features. +- `out`: Number of output features. +- `k`: Maximum number of hops to consider. Default is `3`. +- `bias`: Whether to include a learnable bias term. Default is `true`. +- `init`: Initialization function for the weights. Default is `glorot_uniform`. +- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`. +- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`. + +# Examples + +```julia +# Example graph data +s = [1, 1, 2, 3] +t = [2, 3, 1, 1] +g = GNNGraph(s, t) # Create a graph +x = randn(Float32, 3, g.num_nodes) # Random features for each node + +# Create a TAGConv layer +l = TAGConv(3 => 5, k=3; add_self_loops=true) + +# Apply the TAGConv layer +y = l(g, x) # Output size: 5 × num_nodes +``` +""" +struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer + weight::A + bias::B + k::Int + add_self_loops::Bool + use_edge_weight::Bool +end + + +@functor TAGConv + +function TAGConv(ch::Pair{Int, Int}, k = 3; + init = glorot_uniform, + bias::Bool = true, + add_self_loops = true, + use_edge_weight = false) + in, out = ch + W = init(out, in) + b = bias ? Flux.create_bias(W, true, out) : false + TAGConv(W, b, k, add_self_loops, use_edge_weight) +end + +function (l::TAGConv)(g::GNNGraph, x::AbstractMatrix{T}, + edge_weight::EW = nothing) where + {T, EW <: Union{Nothing, AbstractVector}} + @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs" + + if edge_weight !== nothing + @assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))" + end + + if l.add_self_loops + g = add_self_loops(g) + if edge_weight !== nothing + edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)] + @assert length(edge_weight) == g.num_edges + end + end + Dout, Din = size(l.weight) + if edge_weight !== nothing + d = degree(g, T; dir = :in, edge_weight) + else + d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight) + end + c = 1 ./ sqrt.(d) + + sum_pow = 0 + sum_total = 0 + for iter in 1:(l.k) + x = x .* c' + if edge_weight !== nothing + x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight) + elseif l.use_edge_weight + x = propagate(w_mul_xj, g, +, xj = x) + else + x = propagate(copy_xj, g, +, xj = x) + end + x = x .* c' + + # On the first iteration, initialize sum_pow with the first propagated features + # On subsequent iterations, accumulate propagated features + if iter == 1 + sum_pow = x + sum_total = l.weight * sum_pow + else + sum_pow += x + # Weighted sum of features for each power of adjacency matrix + # This applies the weight matrix to the accumulated sum of propagated features + sum_total += l.weight * sum_pow + end + end + + return (sum_total .+ l.bias) +end + +function (l::TAGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, + edge_weight::AbstractVector) + g = GNNGraph(edge_index(g)...; g.num_nodes) + return l(g, x, edge_weight) +end + +function Base.show(io::IO, l::TAGConv) + out, in = size(l.weight) + print(io, "TAGConv($in => $out") + l.k == 1 || print(io, ", ", l.k) + print(io, ")") +end + @doc raw""" EGNNConv((in, ein) => out; hidden_size=2in, residual=false) EGNNConv(in => out; hidden_size=2in, residual=false) diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 60562b048..5ebea58b8 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -287,6 +287,21 @@ end end end +@testset "TAGConv" begin + K = [1, 2, 3] + for k in K + l = TAGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + + l = TAGConv(in_channel => out_channel, k, add_self_loops = true) + for g in test_graphs + test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes)) + end + end +end + @testset "EGNNConv" begin hin = 5 hout = 5