From 0ae242c47f81ccf923c846a18d8ed0fba70270da Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 21 Jul 2024 15:56:02 +0200 Subject: [PATCH 1/5] Add `DCGRU` code --- src/layers/temporalconv.jl | 45 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index 23df990aa..e38a1683c 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -401,6 +401,51 @@ Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) +struct DCGRUCell + in::Int + out::Int + state0 + K::Int + dconv_u::DConv + dconv_r::DConv + dconv_c::DConv +end + +Flux.@functor DCGRUCell + +function DCGRUCell(ch::Pair{Int,Int}, K::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) + in, out = ch + dconv_u = DConv((in + out) => out, K; bias=bias, init=init) + dconv_r = DConv((in + out) => out, K; bias=bias, init=init) + dconv_c = DConv((in + out) => out, K; bias=bias, init=init) + state0 = init_state(out, n) + return DCGRUCell(in, out, state0, K, dconv_u, dconv_r, dconv_c) +end + +function (dcgru::DCGRUCell)(h, g::GNNGraph, x) + h̃ = vcat(x, h) + z = dcgru.dconv_u(g, h̃) + z = Flux.sigmoid_fast.(z) + r = dcgru.dconv_r(g, h̃) + r = Flux.sigmoid_fast.(r) + ĥ = vcat(x, h .* r) + c = dcgru.dconv_c(g, ĥ) + c = Flux.tanh.(c) + h = z.* h + (1 .- z) .* c + return h, h +end + +function Base.show(io::IO, dcgru::DCGRUCell) + print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.K))") +end + +DCGRU(ch, K, n; kwargs...) 
= Flux.Recur(DCGRUCell(ch, K, n; kwargs...)) +Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) + +(l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) +_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) +_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) + function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector) return l.(tg.snapshots, x) end From 765f9c1652126dcada3aaeeafc6157a0b4325502 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 21 Jul 2024 15:57:47 +0200 Subject: [PATCH 2/5] Add `DCGRU` tests --- test/layers/temporalconv.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl index 2bb7859f6..b55aff808 100644 --- a/test/layers/temporalconv.jl +++ b/test/layers/temporalconv.jl @@ -61,6 +61,14 @@ end @test model(g1) isa GNNGraph end +@testset "DCGRU" begin + dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) + @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) + model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) + @test size(model(g1, g1.ndata.x)) == (1, N) + @test model(g1) isa GNNGraph +end + @testset "GINConv" begin ginconv = GINConv(Dense(in_channel => out_channel),0.3) @test length(ginconv(tg, tg.ndata.x)) == S From e0c0cc2f68a45597fb06af12befec85db6a5f5bc Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Sun, 21 Jul 2024 15:57:58 +0200 Subject: [PATCH 3/5] Add export --- src/GraphNeuralNetworks.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl index 7b4a800aa..66541edc1 100644 --- a/src/GraphNeuralNetworks.jl +++ b/src/GraphNeuralNetworks.jl @@ -77,6 +77,7 @@ export A3TGCN, GConvLSTM, GConvGRU, + DCGRU, # layers/pool GlobalPool, From 8036671ccec2ceb01c77b6c6ec2d021e5728f95b Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Mon, 22 Jul 2024 09:59:30 +0200 Subject: 
[PATCH 4/5] Add docs --- src/layers/temporalconv.jl | 54 ++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index e38a1683c..eb9c925ac 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -405,7 +405,7 @@ struct DCGRUCell in::Int out::Int state0 - K::Int + k::Int dconv_u::DConv dconv_r::DConv dconv_c::DConv @@ -413,13 +413,13 @@ end Flux.@functor DCGRUCell -function DCGRUCell(ch::Pair{Int,Int}, K::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) +function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) in, out = ch - dconv_u = DConv((in + out) => out, K; bias=bias, init=init) - dconv_r = DConv((in + out) => out, K; bias=bias, init=init) - dconv_c = DConv((in + out) => out, K; bias=bias, init=init) + dconv_u = DConv((in + out) => out, k; bias=bias, init=init) + dconv_r = DConv((in + out) => out, k; bias=bias, init=init) + dconv_c = DConv((in + out) => out, k; bias=bias, init=init) state0 = init_state(out, n) - return DCGRUCell(in, out, state0, K, dconv_u, dconv_r, dconv_c) + return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) end function (dcgru::DCGRUCell)(h, g::GNNGraph, x) @@ -436,10 +436,48 @@ function (dcgru::DCGRUCell)(h, g::GNNGraph, x) end function Base.show(io::IO, dcgru::DCGRUCell) - print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.K))") + print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") end -DCGRU(ch, K, n; kwargs...) = Flux.Recur(DCGRUCell(ch, K, n; kwargs...)) +""" + DCGRU(in => out, k, n; [bias, init, init_state]) + +Diffusion Convolutional Gated Recurrent Unit (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural +Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). 
+ +Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `k`: Diffusion step. +- `n`: Number of nodes in the graph. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. +- `init_state`: Initial hidden state of the GRU layer. Default `zeros32`. + +# Examples + +```jldoctest +julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); + +julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); + +julia> y = dcgru(g1, x1); + +julia> size(y) +(5, 5) + +julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); + +julia> z = dcgru(g2, x2); + +julia> size(z) +(5, 5, 30) +``` +""" +DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) From 7732eadb61e45100c6eb73e6fc71acfd66080c49 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 25 Jul 2024 16:12:03 +0200 Subject: [PATCH 5/5] Update src/layers/temporalconv.jl --- src/layers/temporalconv.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl index eb9c925ac..44688cea4 100644 --- a/src/layers/temporalconv.jl +++ b/src/layers/temporalconv.jl @@ -425,12 +425,12 @@ end function (dcgru::DCGRUCell)(h, g::GNNGraph, x) h̃ = vcat(x, h) z = dcgru.dconv_u(g, h̃) - z = Flux.sigmoid_fast.(z) + z = NNlib.sigmoid_fast.(z) r = dcgru.dconv_r(g, h̃) - r = Flux.sigmoid_fast.(r) + r = NNlib.sigmoid_fast.(r) ĥ = vcat(x, h .* r) c = dcgru.dconv_c(g, ĥ) - c = Flux.tanh.(c) + c = tanh.(c) h = z.* h + (1 .- z) .* c return h, h end