Add EvolveGCNO temporal layer (#489)
* First draft

* Add test

* Add export `EvolveGCNO`

* Improve `EvolveGCNO`

* Export `EvolveGCNO`

* Add `EvolveGCNO`

* Fix

* Add `EvolveGCNO` test

* Fix
aurorarossi authored Sep 17, 2024
1 parent c896eda commit 2313a96
Showing 6 changed files with 178 additions and 2 deletions.
3 changes: 2 additions & 1 deletion GNNLux/src/GNNLux.jl
@@ -45,7 +45,8 @@ export TGCN,
A3TGCN,
GConvGRU,
GConvLSTM,
DCGRU
DCGRU,
EvolveGCNO

end #module

60 changes: 60 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
@@ -274,3 +274,63 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)

DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))

@concrete struct EvolveGCNO <: GNNLayer
in_dims::Int
out_dims::Int
use_bias::Bool
init_weight
init_state::Function
init_bias
end

function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
weight = l.init_weight(rng, l.out_dims, l.in_dims)
Wf = l.init_weight(rng, l.out_dims, l.in_dims)
Uf = l.init_weight(rng, l.out_dims, l.in_dims)
Wi = l.init_weight(rng, l.out_dims, l.in_dims)
Ui = l.init_weight(rng, l.out_dims, l.in_dims)
Wo = l.init_weight(rng, l.out_dims, l.in_dims)
Uo = l.init_weight(rng, l.out_dims, l.in_dims)
Wc = l.init_weight(rng, l.out_dims, l.in_dims)
Uc = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
Bf = l.init_bias(rng, l.out_dims, l.in_dims)
Bi = l.init_bias(rng, l.out_dims, l.in_dims)
Bo = l.init_bias(rng, l.out_dims, l.in_dims)
Bc = l.init_bias(rng, l.out_dims, l.in_dims)
return (; conv = (; weight, bias), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc, Bf, Bi, Bo, Bc))
else
return (; conv = (; weight), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc))
end
end

function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
h = l.init_state(rng, l.out_dims, l.in_dims)
c = l.init_state(rng, l.out_dims, l.in_dims)
return (; conv = (;), lstm = (; h, c))
end

function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
H, C = st.lstm
W = ps.conv.weight
m = (; ps.conv.weight, bias = _getbias(ps),
add_self_loops = true, use_edge_weight = true, σ = identity)

X = map(1:tg.num_snapshots) do i
F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
C = F .* C + I .* C̃
H = O .* NNlib.tanh_fast.(C)
W = H
GNNlib.gcn_conv(m, tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
end
return X, (; conv = (;), lstm = (h = H, c = C))
end
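
For reference, the per-snapshot update implemented by the `map` loop above is the "-O" variant of EvolveGCN: the GCN weight matrix itself plays the role of the LSTM hidden state, and every product is elementwise on `(out, in)` matrices. Restated in the notation of the code (an editorial summary, not part of the diff), each snapshot `t` computes:

```math
\begin{aligned}
F_t &= \sigma(W_f \odot W_{t-1} + U_f \odot H_{t-1} + B_f) \\
I_t &= \sigma(W_i \odot W_{t-1} + U_i \odot H_{t-1} + B_i) \\
O_t &= \sigma(W_o \odot W_{t-1} + U_o \odot H_{t-1} + B_o) \\
\tilde{C}_t &= \tanh(W_c \odot W_{t-1} + U_c \odot H_{t-1} + B_c) \\
C_t &= F_t \odot C_{t-1} + I_t \odot \tilde{C}_t \\
H_t &= O_t \odot \tanh(C_t), \qquad W_t = H_t \\
X_t' &= \mathrm{GCNConv}(A_t, X_t;\; \mathrm{weight} = W_t)
\end{aligned}
```

so the convolution applied to snapshot `t` uses the evolved weight `W_t` rather than a single shared weight.
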
11 changes: 11 additions & 0 deletions GNNLux/test/layers/temporalconv_test.jl
@@ -5,6 +5,9 @@
g = rand_graph(rng, 10, 40)
x = randn(rng, Float32, 3, 10)

tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
tx = [x for _ in 1:5]

@testset "TGCN" begin
l = TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
@@ -44,4 +47,12 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "EvolveGCNO" begin
l = EvolveGCNO(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
end
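
A minimal usage sketch of the new Lux layer, mirroring the test above (a sketch only: the module loading and the names `tg`, `tx`, `st_new` are assumptions for illustration):

```julia
using Random, LuxCore, GNNLux, GNNGraphs

rng = Random.default_rng()

# 5 snapshots of a 10-node graph, each with 3 node features
g  = rand_graph(rng, 10, 40)
tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
tx = [randn(rng, Float32, 3, 10) for _ in 1:5]

l  = EvolveGCNO(3 => 3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

# returns a 5-element vector of 3×10 feature matrices and the updated LSTM state
y, st_new = l(tg, tx, ps, st)
```
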
3 changes: 2 additions & 1 deletion src/GraphNeuralNetworks.jl
@@ -54,7 +54,8 @@ export TGCN,
A3TGCN,
GConvLSTM,
GConvGRU,
DCGRU
DCGRU,
EvolveGCNO

include("layers/pool.jl")
export GlobalPool,
97 changes: 97 additions & 0 deletions src/layers/temporalconv.jl
@@ -484,6 +484,103 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

"""
    EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)

Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).

Performs a graph convolution on each snapshot of the temporal graph, with the convolution weights evolved across snapshots by a Long Short-Term Memory (LSTM) cell.

# Arguments

- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initializer for the hidden state of the LSTM layer. Default `zeros32`.

# Examples

```jldoctest
julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10, 20; ndata = rand(4, 10)), rand_graph(10, 14; ndata = rand(4, 10)), rand_graph(10, 22; ndata = rand(4, 10))])
TemporalSnapshotsGNNGraph:
  num_nodes: [10, 10, 10]
  num_edges: [20, 14, 22]
  num_snapshots: 3

julia> ev = EvolveGCNO(4 => 5)
EvolveGCNO(4 => 5)

julia> size(ev(tg, tg.ndata.x))
(3,)

julia> size(ev(tg, tg.ndata.x)[1])
(5, 10)
```
"""
struct EvolveGCNO
conv
W_init
init_state
in::Int
out::Int
Wf
Uf
Bf
Wi
Ui
Bi
Wo
Uo
Bo
Wc
Uc
Bc
end

Flux.@functor EvolveGCNO

function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
in, out = ch
W = init(out, in)
conv = GCNConv(ch; bias = bias, init = init)
Wf = init(out, in)
Uf = init(out, in)
Bf = bias ? init(out, in) : nothing
Wi = init(out, in)
Ui = init(out, in)
Bi = bias ? init(out, in) : nothing
Wo = init(out, in)
Uo = init(out, in)
Bo = bias ? init(out, in) : nothing
Wc = init(out, in)
Uc = init(out, in)
Bc = bias ? init(out, in) : nothing
return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
end

function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
H = egcno.init_state(egcno.out, egcno.in)
C = egcno.init_state(egcno.out, egcno.in)
W = egcno.W_init
X = map(1:tg.num_snapshots) do i
F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
C = F .* C + I .* C̃
H = O .* Flux.tanh_fast.(C)
W = H
egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
end
return X
end

function Base.show(io::IO, egcno::EvolveGCNO)
print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
end

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
end
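
A hedged end-to-end training sketch for the Flux layer on a toy temporal graph (not part of this commit's tests; it assumes Flux's standard `setup`/`update!` optimiser API, and the targets `ys` are placeholders):

```julia
using Flux, GraphNeuralNetworks

# 3 snapshots of a 10-node graph with 4 input features each
tg = TemporalSnapshotsGNNGraph([rand_graph(10, 20; ndata = rand(Float32, 4, 10)) for _ in 1:3])
ys = [rand(Float32, 5, 10) for _ in 1:3]   # hypothetical per-snapshot targets

model = EvolveGCNO(4 => 5)
opt_state = Flux.setup(Adam(1.0f-2), model)

for epoch in 1:10
    grads = Flux.gradient(model) do m
        yhat = m(tg, tg.ndata.x)           # vector of 5×10 matrices, one per snapshot
        sum(Flux.mse(yhat[i], ys[i]) for i in 1:3)
    end
    Flux.update!(opt_state, model, grads[1])
end
```
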
6 changes: 6 additions & 0 deletions test/layers/temporalconv.jl
@@ -69,6 +69,12 @@ end
@test model(g1) isa GNNGraph
end

@testset "EvolveGCNO" begin
evolvegcno = EvolveGCNO(in_channel => out_channel)
@test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
@test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
end

@testset "GINConv" begin
ginconv = GINConv(Dense(in_channel => out_channel),0.3)
@test length(ginconv(tg, tg.ndata.x)) == S
