Added TAGConv layer (#430)
* tagconv first attempt

* tagconv first attempt

* fix

* fix

* Update src/layers/conv.jl

Co-authored-by: Carlo Lucibello <[email protected]>

---------

Co-authored-by: Carlo Lucibello <[email protected]>
rbSparky and CarloLucibello authored Jun 16, 2024
1 parent c5fdf0f commit 942fe91
Showing 3 changed files with 142 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/GraphNeuralNetworks.jl
@@ -65,6 +65,7 @@ export
ResGatedGraphConv,
SAGEConv,
SGConv,
TAGConv,
TransformerConv,

# layers/heteroconv
126 changes: 126 additions & 0 deletions src/layers/conv.jl
@@ -1590,6 +1590,132 @@ function Base.show(io::IO, l::SGConv)
print(io, ")")
end

@doc raw"""
TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false)
TAGConv layer from [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/pdf/1710.10370.pdf).
This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data.
It performs the operation:
```math
H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k}
```
where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`.
# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Maximum number of hops to consider. Default is `3`.
- `bias`: Whether to include a learnable bias term. Default is `true`.
- `init`: Initialization function for the weights. Default is `glorot_uniform`.
- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`.
- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`.
# Examples
```julia
# Example graph data
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t) # Create a graph
x = randn(Float32, 3, g.num_nodes) # Random features for each node
# Create a TAGConv layer
l = TAGConv(3 => 5, k=3; add_self_loops=true)
# Apply the TAGConv layer
y = l(g, x) # Output size: 5 × num_nodes
```
"""
struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer
    weight::A
    bias::B
    k::Int
    add_self_loops::Bool
    use_edge_weight::Bool
end

@functor TAGConv
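# `@functor` makes TAGConv's fields visible to Flux/Functors, so `weight` and
# `bias` are collected as trainable parameters and moved by `gpu`/`cpu`.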

function TAGConv(ch::Pair{Int, Int}, k = 3;
                 init = glorot_uniform,
                 bias::Bool = true,
                 add_self_loops = true,
                 use_edge_weight = false)
    in, out = ch
    W = init(out, in)
    b = bias ? Flux.create_bias(W, true, out) : false
    TAGConv(W, b, k, add_self_loops, use_edge_weight)
end

function (l::TAGConv)(g::GNNGraph, x::AbstractMatrix{T},
                      edge_weight::EW = nothing) where
    {T, EW <: Union{Nothing, AbstractVector}}
    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"

    if edge_weight !== nothing
        @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
    end

    if l.add_self_loops
        g = add_self_loops(g)
        if edge_weight !== nothing
            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
            @assert length(edge_weight) == g.num_edges
        end
    end
    Dout, Din = size(l.weight)
    if edge_weight !== nothing
        d = degree(g, T; dir = :in, edge_weight)
    else
        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
    end
    c = 1 ./ sqrt.(d)
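    # c holds D^{-1/2} for each node; scaling x by c' before and after each
    # propagation step in the loop below applies D^{-1/2} A D^{-1/2} once.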

    sum_pow = 0
    sum_total = 0
    for iter in 1:(l.k)
        x = x .* c'
        if edge_weight !== nothing
            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
        elseif l.use_edge_weight
            x = propagate(w_mul_xj, g, +, xj = x)
        else
            x = propagate(copy_xj, g, +, xj = x)
        end
        x = x .* c'

        # On the first iteration, initialize sum_pow with the first propagated
        # features; on subsequent iterations, accumulate propagated features.
        if iter == 1
            sum_pow = x
            sum_total = l.weight * sum_pow
        else
            sum_pow += x
            # Apply the shared weight matrix to the running sum of features
            # aggregated over increasing powers of the normalized adjacency.
            sum_total += l.weight * sum_pow
        end
    end

    return sum_total .+ l.bias
end

function (l::TAGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
                      edge_weight::AbstractVector)
    g = GNNGraph(edge_index(g)...; g.num_nodes)
    return l(g, x, edge_weight)
end

function Base.show(io::IO, l::TAGConv)
    out, in = size(l.weight)
    print(io, "TAGConv($in => $out")
    l.k == 1 || print(io, ", ", l.k)
    print(io, ")")
end

@doc raw"""
EGNNConv((in, ein) => out; hidden_size=2in, residual=false)
EGNNConv(in => out; hidden_size=2in, residual=false)
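For reference, a minimal usage sketch of the layer added above. It follows the call signatures in the diff; the example graph, the 3 => 5 feature sizes, the edge weights `w`, and the choice `k = 2` are illustrative, not part of the commit:

```julia
using GraphNeuralNetworks, Flux

s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)                  # 3-node example graph
x = randn(Float32, 3, g.num_nodes)  # 3 input features per node
w = Float32[0.5, 1.0, 0.5, 1.0]     # one weight per edge

l = TAGConv(3 => 5, 2)              # k = 2 hops
y = l(g, x)                         # unweighted propagation, 5 × num_nodes
yw = l(g, x, w)                     # explicit edge weights (e_mul_xj path)
```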
15 changes: 15 additions & 0 deletions test/layers/conv.jl
@@ -287,6 +287,21 @@ end
    end
end

@testset "TAGConv" begin
K = [1, 2, 3]
for k in K
l = TAGConv(in_channel => out_channel, k, add_self_loops = true)
for g in test_graphs
test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
        end
    end
end

@testset "EGNNConv" begin
hin = 5
hout = 5
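The testset relies on `test_layer`, which exercises both forward output and gradients. A hand-rolled gradient check along the same lines might look like this sketch; `rand_graph`, the sizes, and the loss are illustrative assumptions, not part of the commit:

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)              # 10 nodes, 30 random edges
x = randn(Float32, 3, g.num_nodes)
l = TAGConv(3 => 5, 2)

# Zygote (via Flux.gradient) returns a NamedTuple mirroring TAGConv's fields.
loss(m) = sum(abs2, m(g, x))
grads = Flux.gradient(loss, l)[1]
@assert size(grads.weight) == size(l.weight)
```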
