From 2313a964a9e99da7b70c5072c99cdac9746fd459 Mon Sep 17 00:00:00 2001
From: Aurora Rossi <65721467+aurorarossi@users.noreply.github.com>
Date: Tue, 17 Sep 2024 11:22:02 +0200
Subject: [PATCH] Add `EvolveGCNO` temporal layer (#489)

* First draft

* Add test

* Add export `EvolveGCNO`

* Improve `EvolveGCNO`

* Export `EvolveGCNO`

* Add `EvolveGCNO`

* Fix

* Add `EvolveGCNO` test

* Fix

---
 GNNLux/src/GNNLux.jl                    |  3 +-
 GNNLux/src/layers/temporalconv.jl       | 60 ++++++++++++++++
 GNNLux/test/layers/temporalconv_test.jl | 11 +++
 src/GraphNeuralNetworks.jl              |  3 +-
 src/layers/temporalconv.jl              | 97 +++++++++++++++++++++++++
 test/layers/temporalconv.jl             |  6 ++
 6 files changed, 178 insertions(+), 2 deletions(-)

diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl
index 8dac6eca2..4a72d8d33 100644
--- a/GNNLux/src/GNNLux.jl
+++ b/GNNLux/src/GNNLux.jl
@@ -45,7 +45,8 @@ export TGCN,
        A3TGCN,
        GConvGRU,
        GConvLSTM,
-       DCGRU
+       DCGRU,
+       EvolveGCNO
 
 end #module
\ No newline at end of file
diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl
index 63c196a55..64ce9b78a 100644
--- a/GNNLux/src/layers/temporalconv.jl
+++ b/GNNLux/src/layers/temporalconv.jl
@@ -274,3 +274,63 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)
 
 DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))
 
+@concrete struct EvolveGCNO <: GNNLayer
+    in_dims::Int
+    out_dims::Int
+    use_bias::Bool
+    init_weight
+    init_state::Function
+    init_bias
+end
+
+function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
+    in_dims, out_dims = ch
+    return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
+end
+
+function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
+    weight = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wf = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uf = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wi = l.init_weight(rng, l.out_dims, l.in_dims)
+    Ui = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wo = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uo = l.init_weight(rng, l.out_dims, l.in_dims)
+    Wc = l.init_weight(rng, l.out_dims, l.in_dims)
+    Uc = l.init_weight(rng, l.out_dims, l.in_dims)
+    if l.use_bias
+        bias = l.init_bias(rng, l.out_dims)
+        Bf = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bi = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bo = l.init_bias(rng, l.out_dims, l.in_dims)
+        Bc = l.init_bias(rng, l.out_dims, l.in_dims)
+        return (; conv = (; weight, bias), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc, Bf, Bi, Bo, Bc))
+    else
+        return (; conv = (; weight), lstm = (; Wf, Uf, Wi, Ui, Wo, Uo, Wc, Uc))
+    end
+end
+
+function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
+    h = l.init_state(rng, l.out_dims, l.in_dims)
+    c = l.init_state(rng, l.out_dims, l.in_dims)
+    return (; conv = (;), lstm = (; h, c))
+end
+
+function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
+    H, C = st.lstm
+    W = ps.conv.weight
+    m = (; ps.conv.weight, bias = _getbias(ps),
+         add_self_loops = true, use_edge_weight = true, σ = identity)
+
+    X = map(1:tg.num_snapshots) do i
+        F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
+        I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
+        O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
+        C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
+        C = F .* C + I .* C̃
+        H = O .* NNlib.tanh_fast.(C)
+        W = H
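+        # the updated hidden state H is the evolved weight matrix; it is passed
+        # as the GCN kernel for snapshot i in the `gcn_conv` call below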
+        GNNlib.gcn_conv(m, tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
+    end
+    return X, (; conv = (;), lstm = (h = H, c = C))
+end
diff --git a/GNNLux/test/layers/temporalconv_test.jl b/GNNLux/test/layers/temporalconv_test.jl
index 7a7c48f4a..ec670b6bc 100644
--- a/GNNLux/test/layers/temporalconv_test.jl
+++ b/GNNLux/test/layers/temporalconv_test.jl
@@ -5,6 +5,9 @@
     g = rand_graph(rng, 10, 40)
     x = randn(rng, Float32, 3, 10)
 
+    tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
+    tx = [x for _ in 1:5]
+
     @testset "TGCN" begin
         l = TGCN(3=>3)
         ps = LuxCore.initialparameters(rng, l)
@@ -44,4 +47,12 @@
         loss = (x, ps) -> sum(first(l(g, x, ps, st)))
         test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
     end
+
+    @testset "EvolveGCNO" begin
+        l = EvolveGCNO(3=>3)
+        ps = LuxCore.initialparameters(rng, l)
+        st = LuxCore.initialstates(rng, l)
+        loss = (tx, ps) -> sum(sum(first(l(tg, tx, ps, st))))
+        test_gradients(loss, tx, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
+    end
 end
\ No newline at end of file
diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index bf6991155..cebf7b7d3 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -54,7 +54,8 @@ export TGCN,
        A3TGCN,
        GConvLSTM,
        GConvGRU,
-       DCGRU
+       DCGRU,
+       EvolveGCNO
 
 include("layers/pool.jl")
 export GlobalPool,
diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index 443ef2a3a..2f6292f28 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -484,6 +484,103 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
 _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)
 
+"""
+    EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+
+Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
+
+Performs a graph convolution whose weights are evolved by a Long Short-Term Memory (LSTM) cell across the snapshots of the temporal graph.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initializer of the LSTM layer's hidden state. Default `Flux.zeros32`.
+
+# Examples
+
+```jldoctest
+julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
+TemporalSnapshotsGNNGraph:
+  num_nodes: [10, 10, 10]
+  num_edges: [20, 14, 22]
+  num_snapshots: 3
+
+julia> ev = EvolveGCNO(4 => 5)
+EvolveGCNO(4 => 5)
+
+julia> size(ev(tg, tg.ndata.x))
+(3,)
+
+julia> size(ev(tg, tg.ndata.x)[1])
+(5, 10)
+```
+"""
+struct EvolveGCNO
+    conv
+    W_init
+    init_state
+    in::Int
+    out::Int
+    Wf
+    Uf
+    Bf
+    Wi
+    Ui
+    Bi
+    Wo
+    Uo
+    Bo
+    Wc
+    Uc
+    Bc
+end
+
+Flux.@functor EvolveGCNO
+
+function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
+    in, out = ch
+    W = init(out, in)
+    conv = GCNConv(ch; bias = bias, init = init)
+    Wf = init(out, in)
+    Uf = init(out, in)
+    Bf = bias ? init(out, in) : nothing
+    Wi = init(out, in)
+    Ui = init(out, in)
+    Bi = bias ? init(out, in) : nothing
+    Wo = init(out, in)
+    Uo = init(out, in)
+    Bo = bias ? init(out, in) : nothing
+    Wc = init(out, in)
+    Uc = init(out, in)
+    Bc = bias ? init(out, in) : nothing
+    return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
+end
+
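+# EvolveGCN-O forward pass: at each snapshot an LSTM cell takes the current
+# weight matrix W as its own input, computes the forget (F), input (I) and
+# output (O) gates and the candidate cell state C̃, and the resulting hidden
+# state H becomes the evolved convolution weight for that snapshot.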
+function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
+    H = egcno.init_state(egcno.out, egcno.in)
+    C = egcno.init_state(egcno.out, egcno.in)
+    W = egcno.W_init
+    X = map(1:tg.num_snapshots) do i
+        F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
+        I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
+        O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
+        C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
+        C = F .* C + I .* C̃
+        H = O .* Flux.tanh_fast.(C)
+        W = H
+        egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
+    end
+    return X
+end
+
+function Base.show(io::IO, egcno::EvolveGCNO)
+    print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
+end
+
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end
diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl
index 45c8acf04..bdf44b45f 100644
--- a/test/layers/temporalconv.jl
+++ b/test/layers/temporalconv.jl
@@ -69,6 +69,12 @@ end
     @test model(g1) isa GNNGraph
 end
 
+@testset "EvolveGCNO" begin
+    evolvegcno = EvolveGCNO(in_channel => out_channel)
+    @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S
+    @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N)
+end
+
 @testset "GINConv" begin
     ginconv = GINConv(Dense(in_channel => out_channel),0.3)
     @test length(ginconv(tg, tg.ndata.x)) == S
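For quick reference, a minimal usage sketch of the new Lux layer, mirroring the `GNNLux` test above. This is an illustration rather than part of the patch: it assumes `rand_graph` and `TemporalSnapshotsGNNGraph` are in scope as in the test suite, and uses `Random.default_rng()` in place of the test RNG.

```julia
using GNNLux, LuxCore, Random

rng = Random.default_rng()

# A temporal graph with 5 identical snapshots: 10 nodes, 40 edges, 3 node features.
g = rand_graph(rng, 10, 40)
x = randn(rng, Float32, 3, 10)
tg = TemporalSnapshotsGNNGraph([g for _ in 1:5])
tx = [x for _ in 1:5]

l = EvolveGCNO(3 => 3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

# The forward pass returns one output matrix per snapshot plus the updated LSTM state.
y, st′ = l(tg, tx, ps, st)
@assert length(y) == 5 && size(y[1]) == (3, 10)
```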