diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index fdab753b9..1e8e7006b 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -74,6 +74,7 @@ export
 # layers/temporalconv
   TGCN,
   A3TGCN,
+  GConvLSTM,
   GConvGRU,

 # layers/pool
diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index 3717374c6..23df990aa 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -279,6 +279,128 @@ Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
 _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
 _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)

+struct GConvLSTMCell <: GNNLayer
+    conv_x_i::ChebConv
+    conv_h_i::ChebConv
+    w_i
+    b_i
+    conv_x_f::ChebConv
+    conv_h_f::ChebConv
+    w_f
+    b_f
+    conv_x_c::ChebConv
+    conv_h_c::ChebConv
+    w_c
+    b_c
+    conv_x_o::ChebConv
+    conv_h_o::ChebConv
+    w_o
+    b_o
+    k::Int
+    state0
+    in::Int
+    out::Int
+end
+
+Flux.@functor GConvLSTMCell
+
+function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int;
+                       bias::Bool = true,
+                       init = Flux.glorot_uniform,
+                       init_state = Flux.zeros32)
+    in, out = ch
+    # input gate
+    conv_x_i = ChebConv(in => out, k; bias, init)
+    conv_h_i = ChebConv(out => out, k; bias, init)
+    w_i = init(out, 1)
+    b_i = bias ? Flux.create_bias(w_i, true, out) : false
+    # forget gate
+    conv_x_f = ChebConv(in => out, k; bias, init)
+    conv_h_f = ChebConv(out => out, k; bias, init)
+    w_f = init(out, 1)
+    b_f = bias ? Flux.create_bias(w_f, true, out) : false
+    # cell state
+    conv_x_c = ChebConv(in => out, k; bias, init)
+    conv_h_c = ChebConv(out => out, k; bias, init)
+    w_c = init(out, 1)
+    b_c = bias ? Flux.create_bias(w_c, true, out) : false
+    # output gate
+    conv_x_o = ChebConv(in => out, k; bias, init)
+    conv_h_o = ChebConv(out => out, k; bias, init)
+    w_o = init(out, 1)
+    b_o = bias ? Flux.create_bias(w_o, true, out) : false
+    state0 = (init_state(out, n), init_state(out, n))
+    return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i,
+                         conv_x_f, conv_h_f, w_f, b_f,
+                         conv_x_c, conv_h_c, w_c, b_c,
+                         conv_x_o, conv_h_o, w_o, b_o,
+                         k, state0, in, out)
+end
+
+function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x)
+    # input gate
+    i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i
+    i = Flux.sigmoid_fast(i)
+    # forget gate
+    f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f
+    f = Flux.sigmoid_fast(f)
+    # cell state
+    c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c)
+    # output gate
+    o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o
+    o = Flux.sigmoid_fast(o)
+    h = o .* Flux.tanh_fast(c)
+    return (h, c), h
+end
+
+function Base.show(io::IO, gclstm::GConvLSTMCell)
+    print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))")
+end
+
+"""
+    GConvLSTM(in => out, k, n; [bias, init, init_state])
+
+Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
+
+Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Chebyshev polynomial order.
+- `n`: Number of nodes in the graph.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initializer of the LSTM hidden and cell states. Default `zeros32`.
+
+# Examples
+
+```jldoctest
+julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes);
+
+julia> y = gclstm(g1, x1);
+
+julia> size(y)
+(5, 5)
+
+julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+julia> z = gclstm(g2, x2);
+
+julia> size(z)
+(5, 5, 30)
+```
+"""
+GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...))
+Flux.Recur(gclstm::GConvLSTMCell) = Flux.Recur(gclstm, gclstm.state0)
+
+(l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x)
+_applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g)
+
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end
diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl
index a947a7e41..2bb7859f6 100644
--- a/test/layers/temporalconv.jl
+++ b/test/layers/temporalconv.jl
@@ -34,6 +34,20 @@ end
     @test model(g1) isa GNNGraph
 end

+@testset "GConvLSTMCell" begin
+    gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes)
+    (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
+    @test size(h) == (out_channel, N)
+    @test size(c) == (out_channel, N)
+end
+
+@testset "GConvLSTM" begin
+    gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes)
+    @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
+    model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
+    @test model(g1) isa GNNGraph
+end
+
 @testset "GConvGRUCell" begin
     gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
     h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x)
@@ -55,7 +69,6 @@ end
     @test length(Flux.gradient(x ->sum(sum(ginconv(tg, x))), tg.ndata.x)[1]) == S
 end
-

 @testset "ChebConv" begin
     chebconv = ChebConv(in_channel => out_channel, 5)
     @test length(chebconv(tg, tg.ndata.x)) == S
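For context, here is a minimal sketch of how the layer added in this diff is meant to be driven across a temporal sequence. It assumes only the `GConvLSTM` constructor introduced above, plus `rand_graph` from this package and standard Flux utilities; the sequence length, feature sizes, and variable names are illustrative:

```julia
using Flux, GraphNeuralNetworks

# Fixed graph: 5 nodes, 10 random edges, 2 input features per node.
g = rand_graph(5, 10)
lstm = GConvLSTM(2 => 4, 2, g.num_nodes)  # Chebyshev order k = 2

# The Flux.Recur wrapper carries the (h, c) state between calls, so a
# temporal sequence is processed by feeding one feature snapshot at a time.
xs = [rand(Float32, 2, g.num_nodes) for _ in 1:8]
ys = [lstm(g, x) for x in xs]             # each output is out × num_nodes, here 4 × 5

Flux.reset!(lstm)  # restore state0 before starting a new sequence
```

Note that `n` (here `g.num_nodes`) is baked into `state0` at construction time, so a single `GConvLSTM` instance assumes a fixed number of nodes across snapshots.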