From 27d13c8f85ce8617b95d8dcbdefead4fe982188d Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 12:11:38 +0100 Subject: [PATCH] rewrite recurrent temporal layers for Flux v0.16 (#560) * GConvGRU * GConvLSTM * GNNRecurrence * EvolveGCNOCell * cleanup * EvolveGCNO * TGCNCell * TGCCN * tests * fix gatedgraphconv * fix set2set --- GNNLux/src/layers/conv.jl | 6 +- GNNlib/src/layers/conv.jl | 3 +- GNNlib/src/layers/pool.jl | 4 +- GraphNeuralNetworks/Project.toml | 4 +- .../src/GraphNeuralNetworks.jl | 15 +- GraphNeuralNetworks/src/layers/conv.jl | 6 + GraphNeuralNetworks/src/layers/pool.jl | 6 - .../src/layers/temporalconv.jl | 1381 ++++++++++------- .../test/layers/temporalconv.jl | 211 ++- GraphNeuralNetworks/test/test_module.jl | 1 + 10 files changed, 1019 insertions(+), 618 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index f92dd1ec6..63c4f90b4 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l function (l::GatedGraphConv)(g, x, ps, st) gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style + # make the forward compatible with Flux.GRUCell style + function fgru(x, h) + y, (h, ) = gru((x, (h,))) + return y, h + end m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) return GNNlib.gated_graph_conv(m, g, x), st end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index fe7c27d9c..bd9bd18b3 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) for i in 1:(l.num_layers) m = view(l.weight, :, :, i) * h m = propagate(copy_xj, g, l.aggr; xj = m) - # in gru forward, hidden state is first argument, input is second - h = l.gru(m, h) + _, h = l.gru(m, h) end return h end diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl index 991e18465..40f983689 100644 --- a/GNNlib/src/layers/pool.jl +++ b/GNNlib/src/layers/pool.jl @@ -31,9 +31,9 @@ function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) qstar = zeros_like(x, (2*n_in, g.num_graphs)) h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) + state = (h, c) for t in 1:l.num_iters - h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs] - q = h + q, state = l.lstm(qstar, state) # [n_in, n_graphs] qn = broadcast_nodes(g, q) # [n_in, n_nodes] α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs] diff --git a/GraphNeuralNetworks/Project.toml b/GraphNeuralNetworks/Project.toml index b659338a8..37f423e28 100644 --- a/GraphNeuralNetworks/Project.toml +++ b/GraphNeuralNetworks/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0-DEV" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" @@ -18,7 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Flux = "0.15" +ConcreteStructs = "0.2.3" +Flux = "0.16.0" GNNGraphs = "1.4" GNNlib = "1" LinearAlgebra = "1" diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index c8df337c8..745e46aaa 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -3,12 +3,13 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random using Flux -using Flux: glorot_uniform, leakyrelu, GRUCell, batch +using Flux: glorot_uniform, leakyrelu, GRUCell, batch, initialstates using MacroTools: @forward using NNlib using ChainRulesCore using Reexport: @reexport using MLUtils: zeros_like +using ConcreteStructs: @concrete using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, @@ -49,12 +50,12 @@ include("layers/heteroconv.jl") export HeteroGraphConv include("layers/temporalconv.jl") -export TGCN, - A3TGCN, - GConvLSTM, - GConvGRU, - DCGRU, - EvolveGCNO +export GNNRecurrence, + GConvGRU, GConvGRUCell, + GConvLSTM, GConvLSTMCell, + DCGRU, DCGRUCell, + EvolveGCNO, EvolveGCNOCell, + TGCN, TGCNCell include("layers/pool.jl") export GlobalPool, diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl index 2cdab6a4d..e3cf30fea 100644 --- a/GraphNeuralNetworks/src/layers/conv.jl +++ b/GraphNeuralNetworks/src/layers/conv.jl @@ -1,3 +1,9 @@ +# The implementations of the forward pass of the graph convolutional layers are in the `GNNlib` module, +# in the src/layers/conv.jl file. The `GNNlib` module is re-exported in the GraphNeuralNetworks module. +# This annoying for the readability of the code, as the user has to look at two different files to understand +# the implementation of a single layer, +# but it is done for GraphNeuralNetworks.jl and GNNLux.jl to be able to share the same code. + @doc raw""" GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight]) diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index 493ef6715..a7d9ceaef 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -155,12 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) return Set2Set(lstm, n_iters) end -function initialstates(cell::LSTMCell) - h = zeros_like(cell.Wh, size(cell.Wh, 2)) - c = zeros_like(cell.Wh, size(cell.Wh, 2)) - return h, c -end - function (l::Set2Set)(g, x) return GNNlib.set2set_pool(l, g, x) end diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 67e85356d..bc12f2fab 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -1,578 +1,885 @@ -struct TGCNCell{C,G} <: GNNLayer - conv::C - gru::G - din::Int - dout::Int +function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} + y = [] + for xt in eachslice(x, dims = 2) + yt, state = cell(g, xt, state) + y = vcat(y, [yt]) + end + return stack(y, dims = 2) end -Flux.@layer TGCNCell +function scan(cell, tg::TemporalSnapshotsGNNGraph, x::AbstractVector, state) + # @assert length(x) == length(tg.snapshots) + y = [] + for (t, xt) in enumerate(x) + gt = tg.snapshots[t] + yt, state = cell(gt, xt, state) + y = vcat(y, [yt]) + end + return y +end + + +""" + GNNRecurrence(cell) + +Construct a recurrent layer that applies the graph recurrent `cell` forward +multiple times to process an entire temporal sequence of node features at once. + +The `cell` has to satisfy the following interface for the forward pass: +`yt, state = cell(g, xt, state)`, where `xt` is the input node features, +`yt` is the updated node features, `state` is the cell state to be updated. + +# Forward + + layer(g, x, [state]) + +Applies the recurrent cell to each timestep of the input sequence. + +## Arguments + +- `g`: The input `GNNGraph` or `TemporalSnapshotsGNNGraph`. + - If `GNNGraph`, the same graph is used for all timesteps. + - If `TemporalSnapshotsGNNGraph`, a different graph is used for each timestep. Not all cells support this. +- `x`: The time-varying node features. + - If `g` is `GNNGraph`, it is an array of size `in x timesteps x num_nodes`. + - If `g` is `TemporalSnapshotsGNNGraph`, it is an vector of length `timesteps`, + with element `t` of size `in x num_nodes_t`. +- `state`: The initial state for the cell. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + +## Return + +Returns the updated node features: +- If `g` is `GNNGraph`, returns an array of size `out_features x timesteps x num_nodes`. +- If `g` is `TemporalSnapshotsGNNGraph`, returns a vector of length `timesteps`, + with element `t` of size `out_features x num_nodes_t`. + +# Examples + +The following example considers a static graph and a time-varying node features. + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); +GNNGraph: + num_nodes: 5 + num_edges: 10 + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> cell = GConvLSTMCell(d_in => d_out, 2) +GConvLSTMCell(2 => 3, 2) # 168 parameters + +julia> layer = GNNRecurrence(cell) +GNNRecurrence( + GConvLSTMCell(2 => 3, 2), # 168 parameters +) # Total: 24 arrays, 168 parameters, 2.023 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +Now consider a time-varying graph and time-varying node features. +```jldoctest +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> num_nodes = [10, 10, 10, 10, 10]; + +julia> num_edges = [10, 12, 14, 16, 18]; + +julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) -function TGCNCell(ch::Pair{Int, Int}; - bias::Bool = true, - init = Flux.glorot_uniform, - add_self_loops = false) - din, dout = ch - conv = GCNConv(din => dout, sigmoid; init, bias, add_self_loops) - gru = GRUCell(dout => dout) - return TGCNCell(conv, gru, din, dout) +julia> x = [rand(Float32, d_in, n) for n in num_nodes]; + +julia> cell = EvolveGCNOCell(d_in => d_out) +EvolveGCNOCell(2 => 3) # 321 parameters + +julia> layer = GNNRecurrence(cell) +GNNRecurrence( + EvolveGCNOCell(2 => 3), # 321 parameters +) # Total: 5 arrays, 321 parameters, 1.535 KiB. + +julia> y = layer(tg, x); + +julia> length(y) # timesteps +5 + +julia> size(y[end]) # (d_out, num_nodes[end]) +(3, 10) +``` +""" +struct GNNRecurrence{G} <: GNNLayer + cell::G end -initialstates(cell::GRUCell) = zeros_like(cell.Wh, size(cell.Wh, 2)) -initialstates(cell::TGCNCell) = initialstates(cell.gru) -(cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat) = cell(g, x, initialstates(cell)) +Flux.@layer GNNRecurrence -function (cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat) - x = cell.conv(g, x) - h = cell.gru(x, h) - return h +Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell) + +function (rnn::GNNRecurrence)(g, x) + return rnn(g, x, initialstates(rnn)) end -function Base.show(io::IO, cell::TGCNCell) - print(io, "TGCNCell($(cell.din) => $(cell.dout))") +function (rnn::GNNRecurrence)(g, x, state) + return scan(rnn.cell, g, x, state) end +function Base.show(io::IO, rnn::GNNRecurrence) + print(io, "GNNRecurrence($(rnn.cell))") +end + + """ - TGCN(din => dout; [bias, init, add_self_loops]) + GConvGRUCell(in => out, k; [bias, init]) -Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). +Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent cell from the paper +[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). -Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. +Uses [`ChebConv`](@ref) to model spatial dependencies, +followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. # Arguments -- `din`: Number of input features. -- `dout`: Number of output features. +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `k`: Chebyshev polynomial order. - `bias`: Add learnable bias. Default `true`. -- `init`: Convolution's weights initializer. Default `glorot_uniform`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. +- `init`: Weights' initializer. Default `glorot_uniform`. # Forward - tgcn(g::GNNGraph, x, [h]) + cell(g::GNNGraph, x, [h]) - `g`: The input graph. -- `x`: The input to the TGCN. It should be a matrix size `din x timesteps` or an array of size `din x timesteps x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `dout x num_nodes`. - If not provided, it is assumed to be a vector of zeros. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `h`: The current hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. + +Performs one recurrence step and returns a tuple `(h, h)`, +where `h` is the updated hidden state of the GRU cell. # Examples ```jldoctest -julia> din, dout = 2, 3; +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; -julia> tgcn = TGCN(din => dout) -TGCN( - TGCNCell( - GCNConv(2 => 3, σ), # 9 parameters - GRUCell(3 => 3), # 63 parameters - ), -) # Total: 5 arrays, 72 parameters, 560 bytes. +julia> d_in, d_out = 2, 3; -julia> num_nodes = 5; num_edges = 10; timesteps = 4; +julia> timesteps = 5; julia> g = rand_graph(num_nodes, num_edges); -julia> x = rand(Float32, din, timesteps, num_nodes); +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = GConvGRUCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state; -julia> tgcn(g, x) |> size -(3, 4, 5) +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) ``` """ -struct TGCN{C<:TGCNCell} <: GNNLayer - cell::C +@concrete struct GConvGRUCell <: GNNLayer + conv_x_r + conv_h_r + conv_x_z + conv_h_z + conv_x_h + conv_h_h + k::Int + in::Int + out::Int end -Flux.@layer TGCN +Flux.@layer :noexpand GConvGRUCell + +function GConvGRUCell(ch::Pair{Int, Int}, k::Int; + bias::Bool = true, + init = Flux.glorot_uniform, + ) + in, out = ch + # reset gate + conv_x_r = ChebConv(in => out, k; bias, init) + conv_h_r = ChebConv(out => out, k; bias, init) + # update gate + conv_x_z = ChebConv(in => out, k; bias, init) + conv_h_z = ChebConv(out => out, k; bias, init) + # new gate + conv_x_h = ChebConv(in => out, k; bias, init) + conv_h_h = ChebConv(out => out, k; bias, init) + return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, in, out) +end -TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...)) +function Flux.initialstates(cell::GConvGRUCell) + zeros_like(cell.conv_x_r.weight, cell.out) +end -initialstates(tgcn::TGCN) = initialstates(tgcn.cell) +(cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) -(tgcn::TGCN)(g::GNNGraph, x) = tgcn(g, x, initialstates(tgcn)) +function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + h = repeat(h, 1, g.num_nodes) + return cell(g, x, h) +end -function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h) - @assert ndims(x) == 2 || ndims(x) == 3 - # [x] = [din, timesteps] or [din, timesteps, num_nodes] - # y = AbstractArray[] # issue https://github.com/JuliaLang/julia/issues/56771 - y = [] - for xt in eachslice(x, dims = 2) - h = tgcn.cell(g, xt, h) - y = vcat(y, [h]) +function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + # reset gate + r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) + r = Flux.sigmoid_fast(r) + # update gate + z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) + z = Flux.sigmoid_fast(z) + # new gate + h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) + h̃ = Flux.tanh_fast(h̃) + h = (1 .- z) .* h̃ .+ z .* h + return h, h +end + +function Base.show(io::IO, cell::GConvGRUCell) + print(io, "GConvGRUCell($(cell.in) => $(cell.out), $(cell.k))") +end + +""" + GConvGRU(args...; kws...) + +Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. + +The arguments are passed to the [`GConvGRUCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = GConvGRU(d_in => d_out, 2) +GConvGRU( + GConvGRUCell(2 => 3, 2), # 108 parameters +) # Total: 12 arrays, 108 parameters, 1.148 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...)) + + +""" + GConvLSTMCell(in => out, k; [bias, init]) + +Graph Convolutional LSTM recurrent cell from the paper +[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). + +Uses [`ChebConv`](@ref) to model spatial dependencies, +followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. + +# Arguments + +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `k`: Chebyshev polynomial order. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The current state of the LSTM cell. + If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`. + If not provided, it is assumed to be a tuple of matrices of zeros. + +Performs one recurrence step and returns a tuple `(output, state)`, +where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = GConvLSTMCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state[1]; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" +@concrete struct GConvLSTMCell <: GNNLayer + conv_x_i + conv_h_i + w_i + b_i + conv_x_f + conv_h_f + w_f + b_f + conv_x_c + conv_h_c + w_c + b_c + conv_x_o + conv_h_o + w_o + b_o + k::Int + in::Int + out::Int +end + +Flux.@layer :noexpand GConvLSTMCell + +function GConvLSTMCell(ch::Pair{Int, Int}, k::Int; + bias::Bool = true, + init = Flux.glorot_uniform) + in, out = ch + # input gate + conv_x_i = ChebConv(in => out, k; bias, init) + conv_h_i = ChebConv(out => out, k; bias, init) + w_i = init(out, 1) + b_i = bias ? Flux.create_bias(w_i, true, out) : false + # forget gate + conv_x_f = ChebConv(in => out, k; bias, init) + conv_h_f = ChebConv(out => out, k; bias, init) + w_f = init(out, 1) + b_f = bias ? Flux.create_bias(w_f, true, out) : false + # cell state + conv_x_c = ChebConv(in => out, k; bias, init) + conv_h_c = ChebConv(out => out, k; bias, init) + w_c = init(out, 1) + b_c = bias ? Flux.create_bias(w_c, true, out) : false + # output gate + conv_x_o = ChebConv(in => out, k; bias, init) + conv_h_o = ChebConv(out => out, k; bias, init) + w_o = init(out, 1) + b_o = bias ? Flux.create_bias(w_o, true, out) : false + return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, + conv_x_f, conv_h_f, w_f, b_f, + conv_x_c, conv_h_c, w_c, b_c, + conv_x_o, conv_h_o, w_o, b_o, + k, in, out) +end + +function Flux.initialstates(cell::GConvLSTMCell) + (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out)) +end + +(cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c)) + if h isa AbstractVector + h = repeat(h, 1, g.num_nodes) + end + if c isa AbstractVector + c = repeat(c, 1, g.num_nodes) end - return stack(y, dims = 2) # [dout, timesteps, num_nodes] -end - -Base.show(io::IO, tgcn::TGCN) = print(io, "TGCN($(tgcn.cell.din) => $(tgcn.cell.dout))") - -######## TO BE PORTED TO FLUX v0.15 from here ############################ - -# """ -# A3TGCN(din => dout; [bias, init, add_self_loops]) - -# Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph -# Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf). - -# Performs a TGCN layer, followed by a soft attention layer. - -# # Arguments - -# - `din`: Number of input features. -# - `dout`: Number of output features. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Convolution's weights initializer. Default `glorot_uniform`. -# - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. - -# # Examples - -# ```jldoctest -# julia> din, dout = 2, 3; - -# julia> model = A3TGCN(din => dout) -# TGCN( -# TGCNCell( -# GCNConv(2 => 3, σ), # 9 parameters -# GRUCell(3 => 3), # 63 parameters -# ), -# ) # Total: 5 arrays, 72 parameters, 560 bytes. - -# julia> num_nodes = 5; num_edges = 10; timesteps = 4; - -# julia> g = rand_graph(num_nodes, num_edges); - -# julia> x = rand(Float32, din, timesteps, num_nodes); - -# julia> model(g, x) |> size -# (3, 4, 5) -# ``` - -# !!! warning "Batch size changes" -# Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. -# """ -# struct A3TGCN <: GNNLayer -# tgcn::TGCN -# dense1::Dense -# dense2::Dense -# din::Int -# dout::Int -# end - -# Flux.@layer A3TGCN - -# function A3TGCN(ch::Pair{Int, Int}, -# bias::Bool = true, -# init = Flux.glorot_uniform, -# add_self_loops = false) -# din, dout = ch -# tgcn = TGCN(din => dout; bias, init, init_state, add_self_loops) -# dense1 = Dense(dout => dout) -# dense2 = Dense(dout => dout) -# return A3TGCN(tgcn, dense1, dense2, din, dout) -# end - -# function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray, h) -# h = a3tgcn.tgcn(g, x, h) -# e = a3tgcn.dense1(h) # WHY NOT RELU? -# e = a3tgcn.dense2(e) -# a = softmax(e, dims = 2) -# c = sum(a .* h , dims = 2) -# if length(size(c)) == 3 -# c = dropdims(c, dims = 2) -# end -# return c -# end - -# function Base.show(io::IO, a3tgcn::A3TGCN) -# print(io, "A3TGCN($(a3tgcn.din) => $(a3tgcn.dout))") -# end - -# struct GConvGRUCell <: GNNLayer -# conv_x_r::ChebConv -# conv_h_r::ChebConv -# conv_x_z::ChebConv -# conv_h_z::ChebConv -# conv_x_h::ChebConv -# conv_h_h::ChebConv -# k::Int -# state0 -# in::Int -# out::Int -# end - -# Flux.@layer GConvGRUCell - -# function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; -# bias::Bool = true, -# init = Flux.glorot_uniform, -# init_state = Flux.zeros32) -# in, out = ch -# # reset gate -# conv_x_r = ChebConv(in => out, k; bias, init) -# conv_h_r = ChebConv(out => out, k; bias, init) -# # update gate -# conv_x_z = ChebConv(in => out, k; bias, init) -# conv_h_z = ChebConv(out => out, k; bias, init) -# # new gate -# conv_x_h = ChebConv(in => out, k; bias, init) -# conv_h_h = ChebConv(out => out, k; bias, init) -# state0 = init_state(out, n) -# return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out) -# end - -# function (ggru::GConvGRUCell)(h, g::GNNGraph, x) -# r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) -# r = Flux.sigmoid_fast(r) -# z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) -# z = Flux.sigmoid_fast(z) -# h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h) -# h̃ = Flux.tanh_fast(h̃) -# h = (1 .- z) .* h̃ .+ z .* h -# return h, h -# end - -# function Base.show(io::IO, ggru::GConvGRUCell) -# print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))") -# end - -# """ -# GConvGRU(in => out, k, n; [bias, init, init_state]) - -# Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -# Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Chebyshev polynomial order. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes); - -# julia> y = ggru(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = ggru(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...)) -# # Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0) - -# # (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g) - -# struct GConvLSTMCell <: GNNLayer -# conv_x_i::ChebConv -# conv_h_i::ChebConv -# w_i -# b_i -# conv_x_f::ChebConv -# conv_h_f::ChebConv -# w_f -# b_f -# conv_x_c::ChebConv -# conv_h_c::ChebConv -# w_c -# b_c -# conv_x_o::ChebConv -# conv_h_o::ChebConv -# w_o -# b_o -# k::Int -# state0 -# in::Int -# out::Int -# end - -# Flux.@layer GConvLSTMCell - -# function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; -# bias::Bool = true, -# init = Flux.glorot_uniform, -# init_state = Flux.zeros32) -# in, out = ch -# # input gate -# conv_x_i = ChebConv(in => out, k; bias, init) -# conv_h_i = ChebConv(out => out, k; bias, init) -# w_i = init(out, 1) -# b_i = bias ? Flux.create_bias(w_i, true, out) : false -# # forget gate -# conv_x_f = ChebConv(in => out, k; bias, init) -# conv_h_f = ChebConv(out => out, k; bias, init) -# w_f = init(out, 1) -# b_f = bias ? Flux.create_bias(w_f, true, out) : false -# # cell state -# conv_x_c = ChebConv(in => out, k; bias, init) -# conv_h_c = ChebConv(out => out, k; bias, init) -# w_c = init(out, 1) -# b_c = bias ? Flux.create_bias(w_c, true, out) : false -# # output gate -# conv_x_o = ChebConv(in => out, k; bias, init) -# conv_h_o = ChebConv(out => out, k; bias, init) -# w_o = init(out, 1) -# b_o = bias ? Flux.create_bias(w_o, true, out) : false -# state0 = (init_state(out, n), init_state(out, n)) -# return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, -# conv_x_f, conv_h_f, w_f, b_f, -# conv_x_c, conv_h_c, w_c, b_c, -# conv_x_o, conv_h_o, w_o, b_o, -# k, state0, in, out) -# end - -# function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x) -# # input gate -# i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i -# i = Flux.sigmoid_fast(i) -# # forget gate -# f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f -# f = Flux.sigmoid_fast(f) -# # cell state -# c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c) -# # output gate -# o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o -# o = Flux.sigmoid_fast(o) -# h = o .* Flux.tanh_fast(c) -# return (h,c), h -# end - -# function Base.show(io::IO, gclstm::GConvLSTMCell) -# print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))") -# end - -# """ -# GConvLSTM(in => out, k, n; [bias, init, init_state]) - -# Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -# Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Chebyshev polynomial order. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes); - -# julia> y = gclstm(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = gclstm(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...)) -# # Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) - -# # (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) - -# struct DCGRUCell -# in::Int -# out::Int -# state0 -# k::Int -# dconv_u::DConv -# dconv_r::DConv -# dconv_c::DConv -# end - -# Flux.@layer DCGRUCell - -# function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# dconv_u = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_r = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_c = DConv((in + out) => out, k; bias=bias, init=init) -# state0 = init_state(out, n) -# return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) -# end - -# function (dcgru::DCGRUCell)(h, g::GNNGraph, x) -# h̃ = vcat(x, h) -# z = dcgru.dconv_u(g, h̃) -# z = NNlib.sigmoid_fast.(z) -# r = dcgru.dconv_r(g, h̃) -# r = NNlib.sigmoid_fast.(r) -# ĥ = vcat(x, h .* r) -# c = dcgru.dconv_c(g, ĥ) -# c = tanh.(c) -# h = z.* h + (1 .- z) .* c -# return h, h -# end - -# function Base.show(io::IO, dcgru::DCGRUCell) -# print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") -# end - -# """ -# DCGRU(in => out, k, n; [bias, init, init_state]) - -# Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural -# Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). - -# Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Diffusion step. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); - -# julia> y = dcgru(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = dcgru(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) -# # Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) - -# # (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) - -# """ -# EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - -# Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). - -# Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. - - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) -# TemporalSnapshotsGNNGraph: -# num_nodes: [10, 10, 10] -# num_edges: [20, 14, 22] -# num_snapshots: 3 - -# julia> ev = EvolveGCNO(4 => 5) -# EvolveGCNO(4 => 5) - -# julia> size(ev(tg, tg.ndata.x)) -# (3,) - -# julia> size(ev(tg, tg.ndata.x)[1]) -# (5, 10) -# ``` -# """ -# struct EvolveGCNO -# conv -# W_init -# init_state -# in::Int -# out::Int -# Wf -# Uf -# Bf -# Wi -# Ui -# Bi -# Wo -# Uo -# Bo -# Wc -# Uc -# Bc -# end - -# function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# W = init(out, in) -# conv = GCNConv(ch; bias = bias, init = init) -# Wf = init(out, in) -# Uf = init(out, in) -# Bf = bias ? init(out, in) : nothing -# Wi = init(out, in) -# Ui = init(out, in) -# Bi = bias ? init(out, in) : nothing -# Wo = init(out, in) -# Uo = init(out, in) -# Bo = bias ? init(out, in) : nothing -# Wc = init(out, in) -# Uc = init(out, in) -# Bc = bias ? init(out, in) : nothing -# return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) -# end - -# function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) -# H = egcno.init_state(egcno.out, egcno.in) -# C = egcno.init_state(egcno.out, egcno.in) -# W = egcno.W_init -# X = map(1:tg.num_snapshots) do i -# F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) -# I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) -# O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) -# C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) -# C = F .* C + I .* C̃ -# H = O .* tanh_fast.(C) -# W = H -# egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) -# end -# return X -# end - -# function Base.show(io::IO, egcno::EvolveGCNO) -# print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") -# end + @assert ndims(h) == 2 && ndims(c) == 2 + # input gate + i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i + i = Flux.sigmoid_fast(i) + # forget gate + f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f + f = Flux.sigmoid_fast(f) + # cell state + c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) + # output gate + o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o + o = Flux.sigmoid_fast(o) + h = o .* Flux.tanh_fast(c) + return h, (h, c) +end + +function Base.show(io::IO, cell::GConvLSTMCell) + print(io, "GConvLSTMCell($(cell.in) => $(cell.out), $(cell.k))") +end + + +""" + GConvLSTM(args...; kws...) + +Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. + +The arguments are passed to the [`GConvLSTMCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = GConvLSTM(d_in => d_out, 2) +GNNRecurrence( + GConvLSTMCell(2 => 3, 2), # 168 parameters +) # Total: 24 arrays, 168 parameters, 2.023 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...)) + +""" + DCGRUCell(in => out, k; [bias, init]) + +Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper +[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926). + +Applyis a [`DConv`](@ref) layer to model spatial dependencies, +in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in`: Number of input node features. +- `out`: Number of output node features. +- `k`: Diffusion step for the `DConv`. +- `bias`: Add learnable bias. Default `true`. +- `init`: Convolution weights' initializer. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [h]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `h`: The current state of the GRU cell. It is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. + +Performs one recurrence step and returns a tuple `(h, h)`, +where `h` is the updated hidden state of the GRU cell. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = DCGRUCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" +struct DCGRUCell + in::Int + out::Int + k::Int + dconv_u::DConv + dconv_r::DConv + dconv_c::DConv +end + +Flux.@layer :noexpand DCGRUCell + +function DCGRUCell(ch::Pair{Int,Int}, k::Int; bias = true, init = glorot_uniform) + in, out = ch + dconv_u = DConv((in + out) => out, k; bias, init) + dconv_r = DConv((in + out) => out, k; bias, init) + dconv_c = DConv((in + out) => out, k; bias, init) + return DCGRUCell(in, out, k, dconv_u, dconv_r, dconv_c) +end + +Flux.initialstates(cell::DCGRUCell) = zeros_like(cell.dconv_u.weights, cell.out) + +(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + return cell(g, x, repeat(h, 1, g.num_nodes)) +end + +function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + h̃ = vcat(x, h) + z = cell.dconv_u(g, h̃) + z = NNlib.sigmoid_fast.(z) + r = cell.dconv_r(g, h̃) + r = NNlib.sigmoid_fast.(r) + ĥ = vcat(x, h .* r) + c = cell.dconv_c(g, ĥ) + c = NNlib.tanh_fast.(c) + h = z.* h + (1 .- z) .* c + return h, h +end + +function Base.show(io::IO, cell::DCGRUCell) + print(io, "DCGRUCell($(cell.in) => $(cell.out), $(cell.k))") +end + +""" + DCGRU(args...; kws...) + +Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. + +The arguments are passed to the [`DCGRUCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = DCGRU(d_in => d_out, 2) +GNNRecurrence( + DCGRUCell(2 => 3, 2), # 189 parameters +) # Total: 6 arrays, 189 parameters, 1.184 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...)) + +"""" + EvolveGCNOCell(in => out; bias = true, init = glorot_uniform) + +Evolving Graph Convolutional Network cell of type "-O" from the paper +[EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191). + +Uses a [`GCNConv`](@ref) layer to model spatial dependencies, and an `LSTMCell` to model temporal dependencies. +Can work with time-varying graphs and node features. + +# Arguments + +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `bias`: Add learnable bias for the convolution and the lstm cell. Default `true`. +- `init`: Weights' initializer for the convolution. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [state]) -> x, state + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The current state of the cell. + A state is a tuple `(weight, lstm)` where `weight` is the convolution's weight and `lstm` is the lstm's state. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + +Returns the updated node features `x` and the updated state. + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = [rand_graph(num_nodes, num_edges) for t in 1:timesteps]; + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell1 = EvolveGCNOCell(d_in => d_out) +EvolveGCNOCell(2 => 3) # 321 parameters + +julia> cell2 = EvolveGCNOCell(d_out => d_out) +EvolveGCNOCell(3 => 3) # 696 parameters + +julia> state1 = Flux.initialstates(cell1); + +julia> state2 = Flux.initialstates(cell2); + +julia> outputs = []; + +julia> for t in 1:timesteps + zt, state1 = cell1(g[t], x[t], state1) + yt, state2 = cell2(g[t], zt, state2) + outputs = vcat(outputs, [yt]) + end + +julia> size(outputs[end]) # (d_out, num_nodes) +(3, 5) +``` +""" +struct EvolveGCNOCell{C,L} <: GNNLayer + in::Int + out::Int + conv::C + lstm::L +end + +Flux.@layer :noexpand EvolveGCNOCell + +function EvolveGCNOCell((in,out)::Pair{Int,Int}; bias = true, init = glorot_uniform) + conv = GCNConv(in => out; bias, init) + lstm = LSTMCell(in*out => in*out; bias) + return EvolveGCNOCell(in, out, conv, lstm) +end + +function Flux.initialstates(cell::EvolveGCNOCell) + weight = reshape(cell.conv.weight, :) + lstm = Flux.initialstates(cell.lstm) + return (; weight, lstm) +end + +(cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state) + weight, state_lstm = cell.lstm(state.weight, state.lstm) + x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) + return x, (; weight, lstm = state_lstm) +end + +function Base.show(io::IO, egcno::EvolveGCNOCell) + print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))") +end + + +""" + EvolveGCNO(args...; kws...) + +Construct a recurrent layer corresponding to the [`EvolveGCNOCell`](@ref) cell. +It can be used to process an entire temporal sequence of graphs and node features at once. + +The arguments are passed to the [`EvolveGCNOCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> num_nodes = [10, 10, 10, 10, 10]; + +julia> num_edges = [10, 12, 14, 16, 18]; + +julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) + +julia> x = [rand(Float32, d_in, n) for n in num_nodes]; + +julia> cell = EvolveGCNO(d_in => d_out) +GNNRecurrence( + EvolveGCNOCell(2 => 3), # 321 parameters +) # Total: 5 arrays, 321 parameters, 1.535 KiB. + +julia> y = layer(tg, x); + +julia> length(y) # timesteps +5 + +julia> size(y[end]) # (d_out, num_nodes[end]) +(3, 10) +``` +""" +EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...)) + + +""" + TGCNCell(in => out; kws...) + +Recurrent graph convolutional cell from the paper +[T-GCN: A Temporal Graph Convolutional +Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320). + +Uses two stacked [`GCNConv`](@ref) layers to model spatial dependencies, +and a GRU mechanism to model temporal dependencies. + +`in` and `out` are the number of input and output node features, respectively. +The keyword arguments are passed to the [`GCNConv`](@ref) constructor. + +# Forward + + cell(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The current state of the cell. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + The state is a matrix of size `out x num_nodes`. + +Returns the updated node features and the updated state. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = DCGRUCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" +@concrete struct TGCNCell <: GNNLayer + in::Int + out::Int + conv_z + dense_z + conv_r + dense_r + conv_h + dense_h +end + +Flux.@layer :noexpand TGCNCell + +function TGCNCell((in, out)::Pair{Int, Int}; kws...) + conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_z = Dense(2*out => out, sigmoid) + conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_r = Dense(2*out => out, sigmoid) + conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_h = Dense(2*out => out, tanh) + return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h) +end + +Flux.initialstates(cell::TGCNCell) = zeros_like(cell.dense_z.weight, cell.out) + +(cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + return cell(g, x, repeat(h, 1, g.num_nodes)) +end + +function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + z = cell.conv_z(g, x) + z = cell.dense_z(vcat(z, h)) + r = cell.conv_r(g, x) + r = cell.dense_r(vcat(r, h)) + h̃ = cell.conv_h(g, x) + h̃ = cell.dense_h(vcat(h̃, r .* h)) + h = (1 .- z) .* h .+ z .* h̃ + return h, h +end + +function Base.show(io::IO, cell::TGCNCell) + print(io, "TGCNCell($(cell.in) => $(cell.out))") +end + +""" + TGCN(args...; kws...) + +Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell. + +The arguments are passed to the [`TGCNCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = TGCN(d_in => d_out) + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...)) + diff --git a/GraphNeuralNetworks/test/layers/temporalconv.jl b/GraphNeuralNetworks/test/layers/temporalconv.jl index 277783d4f..f8d96c0e2 100644 --- a/GraphNeuralNetworks/test/layers/temporalconv.jl +++ b/GraphNeuralNetworks/test/layers/temporalconv.jl @@ -1,16 +1,20 @@ @testmodule TemporalConvTestModule begin using GraphNeuralNetworks - export in_channel, out_channel, N, timesteps, g, tg, RTOL_LOW, RTOL_HIGH, ATOL_LOW + using Statistics + export in_channel, out_channel, N, timesteps, g, tg, cell_loss, + RTOL_LOW, ATOL_LOW, RTOL_HIGH RTOL_LOW = 1e-2 - RTOL_HIGH = 1e-5 ATOL_LOW = 1e-3 + RTOL_HIGH = 1e-5 in_channel = 3 out_channel = 5 N = 4 timesteps = 5 + cell_loss(cell, g, x...) = mean(cell(g, x...)[1]) + g = GNNGraph(rand_graph(N, 8), ndata = rand(Float32, in_channel, N), graph_type = :coo) @@ -22,80 +26,163 @@ end @testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin using .TemporalConvTestModule, .TestModule cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel) - h = cell(g, g.x) + y, h = cell(g, g.x) + @test y === h @test size(h) == (out_channel, g.num_nodes) - test_gradients(cell, g, g.x, rtol = RTOL_HIGH) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH) end @testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin using .TemporalConvTestModule, .TestModule - tgcn = TGCN(in_channel => out_channel) + layer = TGCN(in_channel => out_channel) x = rand(Float32, in_channel, timesteps, g.num_nodes) - h = tgcn(g, x) - @test size(h) == (out_channel, timesteps, g.num_nodes) - test_gradients(tgcn, g, x, rtol = RTOL_HIGH) - test_gradients(tgcn, g, x, h[:,1,:], rtol = RTOL_HIGH) - - # model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) - # @test size(model(g1, g1.ndata.x)) == (1, N) - # @test model(g1) isa GNNGraph + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test layer isa GNNRecurrence + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol = RTOL_HIGH) + # with initial state + test_gradients(layer, g, x, state0, rtol = RTOL_HIGH) + + # interplay with GNNChain + model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW) end -# @testitem "A3TGCN" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# a3tgcn = A3TGCN(in_channel => out_channel) -# @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = GConvLSTMCell(in_channel => out_channel, 2) + y, (h, c) = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + @test size(c) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes) -# (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) -# @test size(h) == (out_channel, N) -# @test size(c) == (out_channel, N) -# end +@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = GConvLSTM(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = (rand(Float32, out_channel, g.num_nodes), rand(Float32, out_channel, g.num_nodes)) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end -# @testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# end +@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = GConvGRUCell(in_channel => out_channel, 2) + y, h = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes) -# h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) -# @test size(h) == (out_channel, N) -# end -# @testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = GConvGRU(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end -# @testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = DCGRUCell(in_channel => out_channel, 2) + y, h = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# evolvegcno = EvolveGCNO(in_channel => out_channel) -# @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S -# @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N) -# end +@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = DCGRU(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end + +@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = EvolveGCNOCell(in_channel => out_channel) + y, state = cell(g, g.x) + @test size(y) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end + +@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = EvolveGCNO(in_channel => out_channel) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = Flux.initialstates(layer) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) +end # @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin # using .TemporalConvTestModule, .TestModule diff --git a/GraphNeuralNetworks/test/test_module.jl b/GraphNeuralNetworks/test/test_module.jl index f21e0a298..8f7a0446b 100644 --- a/GraphNeuralNetworks/test/test_module.jl +++ b/GraphNeuralNetworks/test/test_module.jl @@ -71,6 +71,7 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4) # @assert isapprox(x, y; rtol, atol) if !isapprox(x, y; rtol, atol) equal = false + # @show x y end end end