
Commit 942fe91

Added TAGConv layer (#430)
* tagconv first attempt

* tagconv first attempt

* fix

* fix

* Update src/layers/conv.jl

Co-authored-by: Carlo Lucibello <[email protected]>

---------

Co-authored-by: Carlo Lucibello <[email protected]>
1 parent c5fdf0f commit 942fe91

File tree

3 files changed: +142 -0 lines changed


src/GraphNeuralNetworks.jl

+1
@@ -65,6 +65,7 @@ export
    ResGatedGraphConv,
    SAGEConv,
    SGConv,
    TAGConv,
    TransformerConv,

    # layers/heteroconv
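With `TAGConv` exported, the layer can be used like any other convolution in the package. A minimal end-to-end sketch (illustrative only, not part of this commit; `rand_graph`, `GNNChain`, and `Dense` come from the existing GraphNeuralNetworks.jl/Flux API, and the sizes are arbitrary):

```julia
using GraphNeuralNetworks, Flux

g = rand_graph(10, 30)                   # random graph: 10 nodes, 30 edges
x = randn(Float32, 16, g.num_nodes)      # 16 input features per node

model = GNNChain(TAGConv(16 => 32, 3),   # 3-hop TAGConv
                 Dense(32 => 1))         # plain Flux layer applied to node features
y = model(g, x)                          # 1 × 10 output
```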

src/layers/conv.jl

+126
@@ -1590,6 +1590,132 @@ function Base.show(io::IO, l::SGConv)
    print(io, ")")
end

@doc raw"""
    TAGConv(in => out, k=3; bias=true, init=glorot_uniform, add_self_loops=true, use_edge_weight=false)

TAGConv layer from [Topology Adaptive Graph Convolutional Networks](https://arxiv.org/pdf/1710.10370.pdf).
This layer extends the idea of graph convolutions by applying filters that adapt to the topology of the data.
It performs the operation:

```math
H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k}
```

where `A` is the adjacency matrix of the graph, `D` is the degree matrix, `X` is the input feature matrix, and ``{\Theta}_{k}`` is a unique weight matrix for each hop `k`.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `k`: Maximum number of hops to consider. Default is `3`.
- `bias`: Whether to include a learnable bias term. Default is `true`.
- `init`: Initialization function for the weights. Default is `glorot_uniform`.
- `add_self_loops`: Whether to add self-loops to the adjacency matrix. Default is `true`.
- `use_edge_weight`: If `true`, edge weights are considered in the computation (if available). Default is `false`.

# Examples

```julia
# Example graph data
s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)                  # create a graph
x = randn(Float32, 3, g.num_nodes)  # random features for each node

# Create a TAGConv layer (k is the second positional argument)
l = TAGConv(3 => 5, 3; add_self_loops = true)

# Apply the TAGConv layer
y = l(g, x)                         # output size: 5 × num_nodes
```
"""
struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer
    weight::A
    bias::B
    k::Int
    add_self_loops::Bool
    use_edge_weight::Bool
end

@functor TAGConv

function TAGConv(ch::Pair{Int, Int}, k = 3;
                 init = glorot_uniform,
                 bias::Bool = true,
                 add_self_loops = true,
                 use_edge_weight = false)
    in, out = ch
    W = init(out, in)
    b = bias ? Flux.create_bias(W, true, out) : false
    TAGConv(W, b, k, add_self_loops, use_edge_weight)
end

function (l::TAGConv)(g::GNNGraph, x::AbstractMatrix{T},
                      edge_weight::EW = nothing) where
         {T, EW <: Union{Nothing, AbstractVector}}
    @assert !(g isa GNNGraph{<:ADJMAT_T} && edge_weight !== nothing) "Providing external edge_weight is not yet supported for adjacency matrix graphs"

    if edge_weight !== nothing
        @assert length(edge_weight) == g.num_edges "Wrong number of edge weights (expected $(g.num_edges) but given $(length(edge_weight)))"
    end

    if l.add_self_loops
        g = add_self_loops(g)
        if edge_weight !== nothing
            edge_weight = [edge_weight; fill!(similar(edge_weight, g.num_nodes), 1)]
            @assert length(edge_weight) == g.num_edges
        end
    end
    Dout, Din = size(l.weight)
    if edge_weight !== nothing
        d = degree(g, T; dir = :in, edge_weight)
    else
        d = degree(g, T; dir = :in, edge_weight = l.use_edge_weight)
    end
    c = 1 ./ sqrt.(d)

    sum_pow = 0
    sum_total = 0
    for iter in 1:(l.k)
        x = x .* c'
        if edge_weight !== nothing
            x = propagate(e_mul_xj, g, +, xj = x, e = edge_weight)
        elseif l.use_edge_weight
            x = propagate(w_mul_xj, g, +, xj = x)
        else
            x = propagate(copy_xj, g, +, xj = x)
        end
        x = x .* c'

        # On the first iteration, initialize sum_pow with the first propagated features
        # On subsequent iterations, accumulate propagated features
        if iter == 1
            sum_pow = x
            sum_total = l.weight * sum_pow
        else
            sum_pow += x
            # Weighted sum of features for each power of the adjacency matrix:
            # this applies the weight matrix to the accumulated sum of propagated features
            sum_total += l.weight * sum_pow
        end
    end

    return (sum_total .+ l.bias)
end

function (l::TAGConv)(g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix,
                      edge_weight::AbstractVector)
    g = GNNGraph(edge_index(g)...; g.num_nodes)
    return l(g, x, edge_weight)
end

function Base.show(io::IO, l::TAGConv)
    out, in = size(l.weight)
    print(io, "TAGConv($in => $out")
    l.k == 1 || print(io, ", ", l.k)
    print(io, ")")
end

@doc raw"""
    EGNNConv((in, ein) => out; hidden_size=2in, residual=false)
    EGNNConv(in => out; hidden_size=2in, residual=false)
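For readers of the docstring formula, a dense-matrix reading of ``H^{K} = {\sum}_{k=0}^K (D^{-1/2} A D^{-1/2})^{k} X {\Theta}_{k}`` on the 3-node example graph may help. This is an illustrative sketch only: the per-hop weights `Θ` follow the paper's formulation, whereas the layer above stores a single `weight` matrix and computes the sum with message passing (`propagate`) rather than dense matrix powers.

```julia
using LinearAlgebra

# Adjacency of the docstring example (edges 1↔2 and 1↔3) with self-loops,
# as with `add_self_loops = true`
A = [0 1 1;
     1 0 0;
     1 0 0] + I
d = vec(sum(A, dims = 2))                            # node degrees
S = Diagonal(d .^ -0.5) * A * Diagonal(d .^ -0.5)    # D^{-1/2} A D^{-1/2}

X = randn(3, 4)                     # nodes × features, the layout used by the formula
K = 3
Θ = [randn(4, 5) for _ in 0:K]      # hypothetical per-hop weight matrices Θ_k

H = sum(S^k * X * Θ[k + 1] for k in 0:K)   # H^K: 3 nodes × 5 output features
```

Note the formula works with nodes × features matrices, while the layer itself operates on features × nodes arrays, as in the docstring example.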

test/layers/conv.jl

+15
@@ -287,6 +287,21 @@ end
    end
end

@testset "TAGConv" begin
    K = [1, 2, 3]
    for k in K
        l = TAGConv(in_channel => out_channel, k, add_self_loops = true)
        for g in test_graphs
            test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
        end

        l = TAGConv(in_channel => out_channel, k, add_self_loops = true)
        for g in test_graphs
            test_layer(l, g, rtol = RTOL_HIGH, outsize = (out_channel, g.num_nodes))
        end
    end
end

@testset "EGNNConv" begin
    hin = 5
    hout = 5
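The forward method added in src/layers/conv.jl also accepts an explicit edge-weight vector (and, via `use_edge_weight`, weights stored in the graph). A minimal sketch of the explicit-weight path (illustrative, not part of this commit; the values are arbitrary):

```julia
using GraphNeuralNetworks

s = [1, 1, 2, 3]
t = [2, 3, 1, 1]
g = GNNGraph(s, t)
x = randn(Float32, 3, g.num_nodes)
w = Float32[0.5, 1.0, 0.5, 1.0]     # one weight per edge, length g.num_edges

l = TAGConv(3 => 5, 2)
y = l(g, x, w)                      # weighted propagation; output is 5 × g.num_nodes
```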
