Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added TAGConv layer #430

Merged
merged 5 commits into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ export
ResGatedGraphConv,
SAGEConv,
SGConv,
TAGConv,
TransformerConv,

# layers/heteroconv
Expand Down
126 changes: 126 additions & 0 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,132 @@ function Base.show(io::IO, l::SGConv)
print(io, ")")
end

@doc raw"""
TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false)

TAGConv layer from "Topology Adaptive Graph Convolutional Networks" (https://arxiv.org/pdf/1710.10370.pdf).
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data.
It performs the operation:

```math
H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k}
```

where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`.

# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Maximum number of hops to consider. Default is `3`.
- `bias`: Whether to include a learnable bias term. Default is `true`.
- `init`: Initialization function for the weights. Default is `glorot_uniform`.
- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`.
- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`.

# Examples

```julia
# Example graph data
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t) # Create a graph
x = randn(Float32, 3, g.num_nodes) # Random features for each node

# Create a TAGConv layer
l = TAGConv(3 => 5, k=3; add_self_loops=true)

# Apply the TAGConv layer
y = l(g, x) # Output size: 5 × num_nodes
```
"""
struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer
weight::A
bias::B
k::Int
add_self_loops::Bool
use_edge_weight::Bool
end


@functor TAGConv

function TAGConv(ch::Pair{Int, Int}, k = 3;
init = glorot_uniform,
bias::Bool = true,
add_self_loops = true,
use_edge_weight = false)
in, out = ch
W = init(out, in)
b = bias ? Flux.create_bias(W, true, out) : false
TAGConv(W, b, k, add_self_loops, use_edge_weight)
end

function (l::TAGConv)(g::GNNGraph, x::AbstractMatrix{T},
edge_weight::EW = nothing) where
{T, EW <: Union{Nothing, AbstractVector}}
@assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"

if edge_weight !== nothing
@assert length(edge_weight)==g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
end

if l.add_self_loops
g = add_self_loops(g)
if edge_weight !== nothing
edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
@assert length(edge_weight) == g.num_edges
end
end
Dout, Din = size(l.weight)
if edge_weight !== nothing
d = degree(g, T; dir = :in, edge_weight)
else
d = degree(g, T; dir = :in, edge_weight=l.use_edge_weight)
end
c = 1 ./ sqrt.(d)

sum_pow = 0
sum_total = 0
for iter in 1:(l.k)
x = x .* c'
if edge_weight !== nothing
x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
elseif l.use_edge_weight
x = propagate(w_mul_xj, g, +, xj = x)
else
x = propagate(copy_xj, g, +, xj = x)
end
x = x .* c'

# On the first iteration, initialize sum_pow with the first propagated features
# On subsequent iterations, accumulate propagated features
if iter == 1
sum_pow = x
sum_total = l.weight * sum_pow
rbSparky marked this conversation as resolved.
Show resolved Hide resolved
else
sum_pow += x
# Weighted sum of features for each power of adjacency matrix
# This applies the weight matrix to the accumulated sum of propagated features
sum_total += l.weight * sum_pow
end
end

return (sum_total .+ l.bias)
end

function (l::TAGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
edge_weight::AbstractVector)
g = GNNGraph(edge_index(g)...; g.num_nodes)
return l(g, x, edge_weight)
end

function Base.show(io::IO, l::TAGConv)
out, in = size(l.weight)
print(io, "TAGConv($in => $out")
l.k == 1 || print(io, ", ", l.k)
print(io, ")")
end

@doc raw"""
EGNNConv((in, ein) => out; hidden_size=2in, residual=false)
EGNNConv(in => out; hidden_size=2in, residual=false)
Expand Down
15 changes: 15 additions & 0 deletions test/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,21 @@ end
end
end

@testset "TAGConv" begin
K = [1, 2, 3]
for k in K
l = TAGConv(in_channel => out_channel, k, add_self_loops = true)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
end

l = TAGConv(in_channel => out_channel, k, add_self_loops = true)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
end
end
end

@testset "EGNNConv" begin
hin = 5
hout = 5
Expand Down
Loading