From 43aedc2da28b6298b60f258fc8cf0b989e3390c1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Sun, 15 Dec 2024 22:18:45 +0100 Subject: [PATCH 01/11] GConvGRU --- GraphNeuralNetworks/Project.toml | 4 +- .../src/GraphNeuralNetworks.jl | 7 +- GraphNeuralNetworks/src/layers/conv.jl | 6 + GraphNeuralNetworks/src/layers/pool.jl | 6 - .../src/layers/temporalconv.jl | 690 ++++-------------- .../src/layers/temporalconv_old.jl | 589 +++++++++++++++ 6 files changed, 752 insertions(+), 550 deletions(-) create mode 100644 GraphNeuralNetworks/src/layers/temporalconv_old.jl diff --git a/GraphNeuralNetworks/Project.toml b/GraphNeuralNetworks/Project.toml index b659338a8..37f423e28 100644 --- a/GraphNeuralNetworks/Project.toml +++ b/GraphNeuralNetworks/Project.toml @@ -5,6 +5,7 @@ version = "1.0.0-DEV" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c" GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48" @@ -18,7 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "1" -Flux = "0.15" +ConcreteStructs = "0.2.3" +Flux = "0.16.0" GNNGraphs = "1.4" GNNlib = "1" LinearAlgebra = "1" diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index c8df337c8..10ff5f310 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -3,12 +3,13 @@ module GraphNeuralNetworks using Statistics: mean using LinearAlgebra, Random using Flux -using Flux: glorot_uniform, leakyrelu, GRUCell, batch +using Flux: glorot_uniform, leakyrelu, GRUCell, batch, initialstates using MacroTools: @forward using NNlib using ChainRulesCore using Reexport: @reexport using MLUtils: zeros_like +using ConcreteStructs: @concrete using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T, check_num_nodes, check_num_edges, @@ -49,10 +50,10 @@ include("layers/heteroconv.jl") export HeteroGraphConv include("layers/temporalconv.jl") -export TGCN, +export GConvGRU, GConvGRUCell, + TGCN, A3TGCN, GConvLSTM, - GConvGRU, DCGRU, EvolveGCNO diff --git a/GraphNeuralNetworks/src/layers/conv.jl b/GraphNeuralNetworks/src/layers/conv.jl index 2cdab6a4d..e3cf30fea 100644 --- a/GraphNeuralNetworks/src/layers/conv.jl +++ b/GraphNeuralNetworks/src/layers/conv.jl @@ -1,3 +1,9 @@ +# The implementations of the forward pass of the graph convolutional layers are in the `GNNlib` module, +# in the src/layers/conv.jl file. The `GNNlib` module is re-exported in the GraphNeuralNetworks module. +# This annoying for the readability of the code, as the user has to look at two different files to understand +# the implementation of a single layer, +# but it is done for GraphNeuralNetworks.jl and GNNLux.jl to be able to share the same code. 
+ @doc raw""" GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight]) diff --git a/GraphNeuralNetworks/src/layers/pool.jl b/GraphNeuralNetworks/src/layers/pool.jl index 493ef6715..a7d9ceaef 100644 --- a/GraphNeuralNetworks/src/layers/pool.jl +++ b/GraphNeuralNetworks/src/layers/pool.jl @@ -155,12 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1) return Set2Set(lstm, n_iters) end -function initialstates(cell::LSTMCell) - h = zeros_like(cell.Wh, size(cell.Wh, 2)) - c = zeros_like(cell.Wh, size(cell.Wh, 2)) - return h, c -end - function (l::Set2Set)(g, x) return GNNlib.set2set_pool(l, g, x) end diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 67e85356d..47584bc89 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -1,578 +1,188 @@ -struct TGCNCell{C,G} <: GNNLayer - conv::C - gru::G - din::Int - dout::Int +function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} + y = [] + for x_t in eachslice(x, dims = 2) + yt, state = cell(g, x_t, state) + y = vcat(y, [yt]) + end + return stack(y, dims = 2) end -Flux.@layer TGCNCell +""" + GConvGRUCell(in => out, k; [bias, init]) + +Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent cell from the paper +[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). + +Uses [`ChebConv`](@ref) to model spatial dependencies, +followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `k`: Chebyshev polynomial order. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [h]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. + +Performs one recurrence step and returns a tuple `(h, h)`, +where `h` is the updated hidden state of the GRU cell. 
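+
+Denoting by `conv_x_r`, `conv_h_r`, etc. the order-`k` Chebyshev convolutions stored in the cell,
+one recurrence step computes, in Julia-style pseudocode mirroring the forward pass defined below:
+
+    r  = σ.(conv_x_r(g, x) .+ conv_h_r(g, h))           # reset gate
+    z  = σ.(conv_x_z(g, x) .+ conv_h_z(g, h))           # update gate
+    h̃  = tanh.(conv_x_h(g, x) .+ conv_h_h(g, r .* h))   # candidate state
+    h′ = (1 .- z) .* h̃ .+ z .* h                        # updated hidden state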
+ +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); -function TGCNCell(ch::Pair{Int, Int}; - bias::Bool = true, - init = Flux.glorot_uniform, - add_self_loops = false) - din, dout = ch - conv = GCNConv(din => dout, sigmoid; init, bias, add_self_loops) - gru = GRUCell(dout => dout) - return TGCNCell(conv, gru, din, dout) +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = GConvGRUCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" +@concrete struct GConvGRUCell <: GNNLayer + conv_x_r + conv_h_r + conv_x_z + conv_h_z + conv_x_h + conv_h_h + k::Int + in::Int + out::Int end -initialstates(cell::GRUCell) = zeros_like(cell.Wh, size(cell.Wh, 2)) -initialstates(cell::TGCNCell) = initialstates(cell.gru) -(cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat) = cell(g, x, initialstates(cell)) +Flux.@layer :noexpand GConvGRUCell + +function GConvGRUCell(ch::Pair{Int, Int}, k::Int; + bias::Bool = true, + init = Flux.glorot_uniform, + ) + in, out = ch + # reset gate + conv_x_r = ChebConv(in => out, k; bias, init) + conv_h_r = ChebConv(out => out, k; bias, init) + # update gate + conv_x_z = ChebConv(in => out, k; bias, init) + conv_h_z = ChebConv(out => out, k; bias, init) + # new gate + conv_x_h = ChebConv(in => out, k; bias, init) + conv_h_h = ChebConv(out => out, k; bias, init) + return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, in, out) +end -function (cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat) - x = cell.conv(g, x) - h = cell.gru(x, h) - return h +function Flux.initialstates(cell::GConvGRUCell) + zeros_like(cell.conv_x_r.weight, cell.out) end -function Base.show(io::IO, cell::TGCNCell) - print(io, "TGCNCell($(cell.din) => $(cell.dout))") +(cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + h = repeat(h, 1, g.num_nodes) + return cell(g, x, h) end -""" - TGCN(din => dout; [bias, init, add_self_loops]) +function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + # reset gate + r = cell.conv_x_r(g, x) .+ cell.conv_h_r(g, h) + r = Flux.sigmoid_fast(r) + # update gate + z = cell.conv_x_z(g, x) .+ cell.conv_h_z(g, h) + z = Flux.sigmoid_fast(z) + # new gate + h̃ = cell.conv_x_h(g, x) .+ cell.conv_h_h(g, r .* h) + h̃ = Flux.tanh_fast(h̃) + h = (1 .- z) .* h̃ .+ z .* h + return h, h +end -Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). +function Base.show(io::IO, cell::GConvGRUCell) + print(io, "GConvGRUCell($(cell.in) => $(cell.out), $(cell.k))") +end -Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. +""" + GConvGRU(in => out, k; kws...) -# Arguments +The recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell, +used to process an entire temporal sequence of node features at once. -- `din`: Number of input features. -- `dout`: Number of output features. -- `bias`: Add learnable bias. Default `true`. 
-- `init`: Convolution's weights initializer. Default `glorot_uniform`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. +The arguments are the same as for [`GConvGRUCell`](@ref). # Forward - tgcn(g::GNNGraph, x, [h]) + layer(g::GNNGraph, x, [h]) - `g`: The input graph. -- `x`: The input to the TGCN. It should be a matrix size `din x timesteps` or an array of size `din x timesteps x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `dout x num_nodes`. - If not provided, it is assumed to be a vector of zeros. +- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`. +- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. + +Applies the recurrent cell to each timestep of the input sequence and returns the output as +an array of size `out x timesteps x num_nodes`. # Examples ```jldoctest -julia> din, dout = 2, 3; +julia> num_nodes, num_edges = 5, 10; -julia> tgcn = TGCN(din => dout) -TGCN( - TGCNCell( - GCNConv(2 => 3, σ), # 9 parameters - GRUCell(3 => 3), # 63 parameters - ), -) # Total: 5 arrays, 72 parameters, 560 bytes. +julia> d_in, d_out = 2, 3; -julia> num_nodes = 5; num_edges = 10; timesteps = 4; +julia> timesteps = 5; julia> g = rand_graph(num_nodes, num_edges); -julia> x = rand(Float32, din, timesteps, num_nodes); +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = GConvGRU(d_in => d_out, 2); + +julia> y = layer(g, x); -julia> tgcn(g, x) |> size -(3, 4, 5) +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) ``` -""" -struct TGCN{C<:TGCNCell} <: GNNLayer - cell::C +""" +struct GConvGRU{G <: GConvGRUCell} <: GNNLayer + cell::G end -Flux.@layer TGCN +Flux.@layer GConvGRU -TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...)) +function GConvGRU(ch::Pair{Int,Int}, k::Int; kws...) + return GConvGRU(GConvGRUCell(ch, k; kws...)) +end -initialstates(tgcn::TGCN) = initialstates(tgcn.cell) +Flux.initialstates(rnn::GConvGRU) = Flux.initialstates(rnn.cell) -(tgcn::TGCN)(g::GNNGraph, x) = tgcn(g, x, initialstates(tgcn)) +function (rnn::GConvGRU)(g::GNNGraph, x::AbstractArray) + return scan(rnn.cell, g, x, initialstates(rnn)) +end -function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h) - @assert ndims(x) == 2 || ndims(x) == 3 - # [x] = [din, timesteps] or [din, timesteps, num_nodes] - # y = AbstractArray[] # issue https://github.com/JuliaLang/julia/issues/56771 - y = [] - for xt in eachslice(x, dims = 2) - h = tgcn.cell(g, xt, h) - y = vcat(y, [h]) - end - return stack(y, dims = 2) # [dout, timesteps, num_nodes] +function Base.show(io::IO, rnn::GConvGRU) + print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))") end -Base.show(io::IO, tgcn::TGCN) = print(io, "TGCN($(tgcn.cell.din) => $(tgcn.cell.dout))") - -######## TO BE PORTED TO FLUX v0.15 from here ############################ - -# """ -# A3TGCN(din => dout; [bias, init, add_self_loops]) - -# Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph -# Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf). - -# Performs a TGCN layer, followed by a soft attention layer. - -# # Arguments - -# - `din`: Number of input features. -# - `dout`: Number of output features. -# - `bias`: Add learnable bias. Default `true`. 
-# - `init`: Convolution's weights initializer. Default `glorot_uniform`. -# - `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. - -# # Examples - -# ```jldoctest -# julia> din, dout = 2, 3; - -# julia> model = A3TGCN(din => dout) -# TGCN( -# TGCNCell( -# GCNConv(2 => 3, σ), # 9 parameters -# GRUCell(3 => 3), # 63 parameters -# ), -# ) # Total: 5 arrays, 72 parameters, 560 bytes. - -# julia> num_nodes = 5; num_edges = 10; timesteps = 4; - -# julia> g = rand_graph(num_nodes, num_edges); - -# julia> x = rand(Float32, din, timesteps, num_nodes); - -# julia> model(g, x) |> size -# (3, 4, 5) -# ``` - -# !!! warning "Batch size changes" -# Failing to call `reset!` when the input batch size changes can lead to unexpected behavior. -# """ -# struct A3TGCN <: GNNLayer -# tgcn::TGCN -# dense1::Dense -# dense2::Dense -# din::Int -# dout::Int -# end - -# Flux.@layer A3TGCN - -# function A3TGCN(ch::Pair{Int, Int}, -# bias::Bool = true, -# init = Flux.glorot_uniform, -# add_self_loops = false) -# din, dout = ch -# tgcn = TGCN(din => dout; bias, init, init_state, add_self_loops) -# dense1 = Dense(dout => dout) -# dense2 = Dense(dout => dout) -# return A3TGCN(tgcn, dense1, dense2, din, dout) -# end - -# function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray, h) -# h = a3tgcn.tgcn(g, x, h) -# e = a3tgcn.dense1(h) # WHY NOT RELU? -# e = a3tgcn.dense2(e) -# a = softmax(e, dims = 2) -# c = sum(a .* h , dims = 2) -# if length(size(c)) == 3 -# c = dropdims(c, dims = 2) -# end -# return c -# end - -# function Base.show(io::IO, a3tgcn::A3TGCN) -# print(io, "A3TGCN($(a3tgcn.din) => $(a3tgcn.dout))") -# end - -# struct GConvGRUCell <: GNNLayer -# conv_x_r::ChebConv -# conv_h_r::ChebConv -# conv_x_z::ChebConv -# conv_h_z::ChebConv -# conv_x_h::ChebConv -# conv_h_h::ChebConv -# k::Int -# state0 -# in::Int -# out::Int -# end - -# Flux.@layer GConvGRUCell - -# function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int; -# bias::Bool = true, -# init = Flux.glorot_uniform, -# init_state = Flux.zeros32) -# in, out = ch -# # reset gate -# conv_x_r = ChebConv(in => out, k; bias, init) -# conv_h_r = ChebConv(out => out, k; bias, init) -# # update gate -# conv_x_z = ChebConv(in => out, k; bias, init) -# conv_h_z = ChebConv(out => out, k; bias, init) -# # new gate -# conv_x_h = ChebConv(in => out, k; bias, init) -# conv_h_h = ChebConv(out => out, k; bias, init) -# state0 = init_state(out, n) -# return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out) -# end - -# function (ggru::GConvGRUCell)(h, g::GNNGraph, x) -# r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) -# r = Flux.sigmoid_fast(r) -# z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) -# z = Flux.sigmoid_fast(z) -# h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h) -# h̃ = Flux.tanh_fast(h̃) -# h = (1 .- z) .* h̃ .+ z .* h -# return h, h -# end - -# function Base.show(io::IO, ggru::GConvGRUCell) -# print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))") -# end - -# """ -# GConvGRU(in => out, k, n; [bias, init, init_state]) - -# Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -# Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Chebyshev polynomial order. 
-# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the GRU layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes); - -# julia> y = ggru(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = ggru(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...)) -# # Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0) - -# # (l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g) - -# struct GConvLSTMCell <: GNNLayer -# conv_x_i::ChebConv -# conv_h_i::ChebConv -# w_i -# b_i -# conv_x_f::ChebConv -# conv_h_f::ChebConv -# w_f -# b_f -# conv_x_c::ChebConv -# conv_h_c::ChebConv -# w_c -# b_c -# conv_x_o::ChebConv -# conv_h_o::ChebConv -# w_o -# b_o -# k::Int -# state0 -# in::Int -# out::Int -# end - -# Flux.@layer GConvLSTMCell - -# function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; -# bias::Bool = true, -# init = Flux.glorot_uniform, -# init_state = Flux.zeros32) -# in, out = ch -# # input gate -# conv_x_i = ChebConv(in => out, k; bias, init) -# conv_h_i = ChebConv(out => out, k; bias, init) -# w_i = init(out, 1) -# b_i = bias ? Flux.create_bias(w_i, true, out) : false -# # forget gate -# conv_x_f = ChebConv(in => out, k; bias, init) -# conv_h_f = ChebConv(out => out, k; bias, init) -# w_f = init(out, 1) -# b_f = bias ? Flux.create_bias(w_f, true, out) : false -# # cell state -# conv_x_c = ChebConv(in => out, k; bias, init) -# conv_h_c = ChebConv(out => out, k; bias, init) -# w_c = init(out, 1) -# b_c = bias ? Flux.create_bias(w_c, true, out) : false -# # output gate -# conv_x_o = ChebConv(in => out, k; bias, init) -# conv_h_o = ChebConv(out => out, k; bias, init) -# w_o = init(out, 1) -# b_o = bias ? Flux.create_bias(w_o, true, out) : false -# state0 = (init_state(out, n), init_state(out, n)) -# return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, -# conv_x_f, conv_h_f, w_f, b_f, -# conv_x_c, conv_h_c, w_c, b_c, -# conv_x_o, conv_h_o, w_o, b_o, -# k, state0, in, out) -# end - -# function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x) -# # input gate -# i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i -# i = Flux.sigmoid_fast(i) -# # forget gate -# f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f -# f = Flux.sigmoid_fast(f) -# # cell state -# c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c) -# # output gate -# o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o -# o = Flux.sigmoid_fast(o) -# h = o .* Flux.tanh_fast(c) -# return (h,c), h -# end - -# function Base.show(io::IO, gclstm::GConvLSTMCell) -# print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))") -# end - -# """ -# GConvLSTM(in => out, k, n; [bias, init, init_state]) - -# Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). 
- -# Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Chebyshev polynomial order. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes); - -# julia> y = gclstm(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = gclstm(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...)) -# # Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) - -# # (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) - -# struct DCGRUCell -# in::Int -# out::Int -# state0 -# k::Int -# dconv_u::DConv -# dconv_r::DConv -# dconv_c::DConv -# end - -# Flux.@layer DCGRUCell - -# function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# dconv_u = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_r = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_c = DConv((in + out) => out, k; bias=bias, init=init) -# state0 = init_state(out, n) -# return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) -# end - -# function (dcgru::DCGRUCell)(h, g::GNNGraph, x) -# h̃ = vcat(x, h) -# z = dcgru.dconv_u(g, h̃) -# z = NNlib.sigmoid_fast.(z) -# r = dcgru.dconv_r(g, h̃) -# r = NNlib.sigmoid_fast.(r) -# ĥ = vcat(x, h .* r) -# c = dcgru.dconv_c(g, ĥ) -# c = tanh.(c) -# h = z.* h + (1 .- z) .* c -# return h, h -# end - -# function Base.show(io::IO, dcgru::DCGRUCell) -# print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") -# end - -# """ -# DCGRU(in => out, k, n; [bias, init, init_state]) - -# Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural -# Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). - -# Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Diffusion step. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); - -# julia> y = dcgru(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = dcgru(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # DCGRU(ch, k, n; kwargs...) 
= Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) -# # Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) - -# # (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) - -# """ -# EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - -# Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). - -# Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. - - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) -# TemporalSnapshotsGNNGraph: -# num_nodes: [10, 10, 10] -# num_edges: [20, 14, 22] -# num_snapshots: 3 - -# julia> ev = EvolveGCNO(4 => 5) -# EvolveGCNO(4 => 5) - -# julia> size(ev(tg, tg.ndata.x)) -# (3,) - -# julia> size(ev(tg, tg.ndata.x)[1]) -# (5, 10) -# ``` -# """ -# struct EvolveGCNO -# conv -# W_init -# init_state -# in::Int -# out::Int -# Wf -# Uf -# Bf -# Wi -# Ui -# Bi -# Wo -# Uo -# Bo -# Wc -# Uc -# Bc -# end - -# function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# W = init(out, in) -# conv = GCNConv(ch; bias = bias, init = init) -# Wf = init(out, in) -# Uf = init(out, in) -# Bf = bias ? init(out, in) : nothing -# Wi = init(out, in) -# Ui = init(out, in) -# Bi = bias ? init(out, in) : nothing -# Wo = init(out, in) -# Uo = init(out, in) -# Bo = bias ? init(out, in) : nothing -# Wc = init(out, in) -# Uc = init(out, in) -# Bc = bias ? 
init(out, in) : nothing -# return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) -# end - -# function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) -# H = egcno.init_state(egcno.out, egcno.in) -# C = egcno.init_state(egcno.out, egcno.in) -# W = egcno.W_init -# X = map(1:tg.num_snapshots) do i -# F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) -# I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) -# O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) -# C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) -# C = F .* C + I .* C̃ -# H = O .* tanh_fast.(C) -# W = H -# egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) -# end -# return X -# end - -# function Base.show(io::IO, egcno::EvolveGCNO) -# print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") -# end diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl new file mode 100644 index 000000000..bf1f7ef32 --- /dev/null +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -0,0 +1,589 @@ +function scan(cell, g, x, state) + y = [] + for x_t in eachslice(x, dims = 2) + yt, state = cell(g, x_t, state) + y = vcat(y, [yt]) + end + return stack(y, dims = 2) +end + + +struct TGCNCell{C,G} <: GNNLayer + conv::C + gru::G + in::Int + out::Int +end + +Flux.@layer :noexpand TGCNCell + +function TGCNCell(ch::Pair{Int, Int}; + bias::Bool = true, + init = Flux.glorot_uniform, + add_self_loops = false) + in, out = ch + conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops) + gru = GRUCell(out => out) + return TGCNCell(conv, gru, in, out) +end + +Flux.initialstates(cell::TGCNCell) = initialstates(cell.gru) +(cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat) = cell(g, x, initialstates(cell)) + +function (cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat) + x = cell.conv(g, x) + h = cell.gru(x, h) + return h +end + +function Base.show(io::IO, cell::TGCNCell) + print(io, "TGCNCell($(cell.in) => $(cell.out))") +end + +""" + TGCN(in => out; [bias, init, add_self_loops]) + +Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). + +Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `bias`: Add learnable bias. Default `true`. +- `init`: Convolution's weights initializer. Default `glorot_uniform`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. + +# Forward + + tgcn(g::GNNGraph, x, [h]) + +- `g`: The input graph. +- `x`: The input to the TGCN. It should be a matrix size `in x timesteps` or an array of size `in x timesteps x num_nodes`. +- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a vector of zeros. + +# Examples + +```jldoctest +julia> in, out = 2, 3; + +julia> tgcn = TGCN(in => out) +TGCN( + TGCNCell( + GCNConv(2 => 3, σ), # 9 parameters + GRUCell(3 => 3), # 63 parameters + ), +) # Total: 5 arrays, 72 parameters, 560 bytes. 
+ +julia> num_nodes = 5; num_edges = 10; timesteps = 4; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, in, timesteps, num_nodes); + +julia> tgcn(g, x) |> size +(3, 4, 5) +``` +""" +struct TGCN{C<:TGCNCell} <: GNNLayer + cell::C +end + +Flux.@layer TGCN + +TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...)) + +initialstates(tgcn::TGCN) = initialstates(tgcn.cell) + +(tgcn::TGCN)(g::GNNGraph, x) = tgcn(g, x, initialstates(tgcn)) + +function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h) + return scan(tgcn.cell, g, x, h) +end + +Base.show(io::IO, tgcn::TGCN) = print(io, "TGCN($(tgcn.cell.in) => $(tgcn.cell.out))") + +####### TO BE PORTED TO FLUX v0.15 from here ############################ + +""" + A3TGCN(in => out; [bias, init, add_self_loops]) + +Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph +Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf). + +Performs a TGCN layer, followed by a soft attention layer. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `bias`: Add learnable bias. Default `true`. +- `init`: Convolution's weights initializer. Default `glorot_uniform`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. + +# Examples + +```jldoctest +julia> in, out = 2, 3; + +julia> model = A3TGCN(in => out) +TGCN( + TGCNCell( + GCNConv(2 => 3, σ), # 9 parameters + GRUCell(3 => 3), # 63 parameters + ), +) # Total: 5 arrays, 72 parameters, 560 bytes. + +julia> num_nodes = 5; num_edges = 10; timesteps = 4; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, in, timesteps, num_nodes); + +julia> model(g, x) |> size +(3, 4, 5) +``` +""" +struct A3TGCN <: GNNLayer + tgcn::TGCN + dense1::Dense + dense2::Dense + in::Int + out::Int +end + +Flux.@layer A3TGCN + +function A3TGCN(ch::Pair{Int, Int}; kws...) + in, out = ch + tgcn = TGCN(in => out; kws...) 
+ dense1 = Dense(out => out) + dense2 = Dense(out => out) + return A3TGCN(tgcn, dense1, dense2, in, out) +end + +function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray, h) + h = a3tgcn.tgcn(g, x, h) # [out, timesteps, num_nodes] + logits = a3tgcn.dense1(h) + logits = a3tgcn.dense2(logits) # [out, timesteps, num_nodes] + a = softmax(logits, dims=2) # TODO handle multiple graphs + c = sum(a .* h, dims=2) + c = dropdims(c, dims=2) # [out, num_nodes] + return c +end + +function Base.show(io::IO, a3tgcn::A3TGCN) + print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") +end + +@concrete struct GConvGRUCell <: GNNLayer + conv_x_r + conv_h_r + conv_x_z + conv_h_z + conv_x_h + conv_h_h + k::Int + in::Int + out::Int +end + +Flux.@layer :noexpand GConvGRUCell + +function GConvGRUCell(ch::Pair{Int, Int}, k::Int; + bias::Bool = true, + init = Flux.glorot_uniform, + ) + in, out = ch + # reset gate + conv_x_r = ChebConv(in => out, k; bias, init) + conv_h_r = ChebConv(out => out, k; bias, init) + # update gate + conv_x_z = ChebConv(in => out, k; bias, init) + conv_h_z = ChebConv(out => out, k; bias, init) + # new gate + conv_x_h = ChebConv(in => out, k; bias, init) + conv_h_h = ChebConv(out => out, k; bias, init) + return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, in, out) +end + +function Flux.initialstates(cell::GConvGRUCell) + zeros_like(cell.conv_x_r.weight, cell.out) +end + +(cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + h = repeat(h, 1, g.num_nodes) + return cell(g, x, h) +end + +function (ggru::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) + r = Flux.sigmoid_fast(r) + z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) + z = Flux.sigmoid_fast(z) + h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h) + h̃ = Flux.tanh_fast(h̃) + h = (1 .- z) .* h̃ .+ z .* h + return h, h +end + +function Base.show(io::IO, ggru::GConvGRUCell) + print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))") +end + +""" + GConvGRU(in => out, k, n; [bias, init, init_state]) + +Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper +[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). + +Uses [`ChebConv`](@ref) to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `k`: Chebyshev polynomial order. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. + +# Forward + + ggru(g::GNNGraph, x, [h]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. 
+ +# Examples + +```jldoctest +julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); + +julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes); + +julia> y = ggru(g1, x1); + +julia> size(y) +(5, 5) + +julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); + +julia> z = ggru(g2, x2); + +julia> size(z) +(5, 5, 30) +``` +""" + +struct GConvGRU{G <: GConvGRUCell} <: GNNLayer + cell::G +end + +Flux.@layer GConvGRU + + +# struct GConvLSTMCell <: GNNLayer +# conv_x_i::ChebConv +# conv_h_i::ChebConv +# w_i +# b_i +# conv_x_f::ChebConv +# conv_h_f::ChebConv +# w_f +# b_f +# conv_x_c::ChebConv +# conv_h_c::ChebConv +# w_c +# b_c +# conv_x_o::ChebConv +# conv_h_o::ChebConv +# w_o +# b_o +# k::Int +# state0 +# in::Int +# out::Int +# end + +# Flux.@layer GConvLSTMCell + +# function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; +# bias::Bool = true, +# init = Flux.glorot_uniform, +# init_state = Flux.zeros32) +# in, out = ch +# # input gate +# conv_x_i = ChebConv(in => out, k; bias, init) +# conv_h_i = ChebConv(out => out, k; bias, init) +# w_i = init(out, 1) +# b_i = bias ? Flux.create_bias(w_i, true, out) : false +# # forget gate +# conv_x_f = ChebConv(in => out, k; bias, init) +# conv_h_f = ChebConv(out => out, k; bias, init) +# w_f = init(out, 1) +# b_f = bias ? Flux.create_bias(w_f, true, out) : false +# # cell state +# conv_x_c = ChebConv(in => out, k; bias, init) +# conv_h_c = ChebConv(out => out, k; bias, init) +# w_c = init(out, 1) +# b_c = bias ? Flux.create_bias(w_c, true, out) : false +# # output gate +# conv_x_o = ChebConv(in => out, k; bias, init) +# conv_h_o = ChebConv(out => out, k; bias, init) +# w_o = init(out, 1) +# b_o = bias ? Flux.create_bias(w_o, true, out) : false +# state0 = (init_state(out, n), init_state(out, n)) +# return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, +# conv_x_f, conv_h_f, w_f, b_f, +# conv_x_c, conv_h_c, w_c, b_c, +# conv_x_o, conv_h_o, w_o, b_o, +# k, state0, in, out) +# end + +# function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x) +# # input gate +# i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i +# i = Flux.sigmoid_fast(i) +# # forget gate +# f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f +# f = Flux.sigmoid_fast(f) +# # cell state +# c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c) +# # output gate +# o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o +# o = Flux.sigmoid_fast(o) +# h = o .* Flux.tanh_fast(c) +# return (h,c), h +# end + +# function Base.show(io::IO, gclstm::GConvLSTMCell) +# print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))") +# end + +# """ +# GConvLSTM(in => out, k, n; [bias, init, init_state]) + +# Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). + +# Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. + +# # Arguments + +# - `in`: Number of input features. +# - `out`: Number of output features. +# - `k`: Chebyshev polynomial order. +# - `n`: Number of nodes in the graph. +# - `bias`: Add learnable bias. Default `true`. +# - `init`: Weights' initializer. Default `glorot_uniform`. +# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. 
+ +# # Examples + +# ```jldoctest +# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); + +# julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes); + +# julia> y = gclstm(g1, x1); + +# julia> size(y) +# (5, 5) + +# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); + +# julia> z = gclstm(g2, x2); + +# julia> size(z) +# (5, 5, 30) +# ``` +# """ +# # GConvLSTM(ch, k, n; kwargs...) = Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...)) +# # Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) + +# # (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) +# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) +# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) + +# struct DCGRUCell +# in::Int +# out::Int +# state0 +# k::Int +# dconv_u::DConv +# dconv_r::DConv +# dconv_c::DConv +# end + +# Flux.@layer DCGRUCell + +# function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) +# in, out = ch +# dconv_u = DConv((in + out) => out, k; bias=bias, init=init) +# dconv_r = DConv((in + out) => out, k; bias=bias, init=init) +# dconv_c = DConv((in + out) => out, k; bias=bias, init=init) +# state0 = init_state(out, n) +# return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) +# end + +# function (dcgru::DCGRUCell)(h, g::GNNGraph, x) +# h̃ = vcat(x, h) +# z = dcgru.dconv_u(g, h̃) +# z = NNlib.sigmoid_fast.(z) +# r = dcgru.dconv_r(g, h̃) +# r = NNlib.sigmoid_fast.(r) +# ĥ = vcat(x, h .* r) +# c = dcgru.dconv_c(g, ĥ) +# c = tanh.(c) +# h = z.* h + (1 .- z) .* c +# return h, h +# end + +# function Base.show(io::IO, dcgru::DCGRUCell) +# print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") +# end + +# """ +# DCGRU(in => out, k, n; [bias, init, init_state]) + +# Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural +# Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). + +# Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. + +# # Arguments + +# - `in`: Number of input features. +# - `out`: Number of output features. +# - `k`: Diffusion step. +# - `n`: Number of nodes in the graph. +# - `bias`: Add learnable bias. Default `true`. +# - `init`: Weights' initializer. Default `glorot_uniform`. +# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. + +# # Examples + +# ```jldoctest +# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); + +# julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); + +# julia> y = dcgru(g1, x1); + +# julia> size(y) +# (5, 5) + +# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); + +# julia> z = dcgru(g2, x2); + +# julia> size(z) +# (5, 5, 30) +# ``` +# """ +# # DCGRU(ch, k, n; kwargs...) = Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) +# # Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) + +# # (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) +# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) +# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) + +# """ +# EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) + +# Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). 
+ +# Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. + + +# # Arguments + +# - `in`: Number of input features. +# - `out`: Number of output features. +# - `bias`: Add learnable bias. Default `true`. +# - `init`: Weights' initializer. Default `glorot_uniform`. +# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. + +# # Examples + +# ```jldoctest +# julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) +# TemporalSnapshotsGNNGraph: +# num_nodes: [10, 10, 10] +# num_edges: [20, 14, 22] +# num_snapshots: 3 + +# julia> ev = EvolveGCNO(4 => 5) +# EvolveGCNO(4 => 5) + +# julia> size(ev(tg, tg.ndata.x)) +# (3,) + +# julia> size(ev(tg, tg.ndata.x)[1]) +# (5, 10) +# ``` +# """ +# struct EvolveGCNO +# conv +# W_init +# init_state +# in::Int +# out::Int +# Wf +# Uf +# Bf +# Wi +# Ui +# Bi +# Wo +# Uo +# Bo +# Wc +# Uc +# Bc +# end + +# function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) +# in, out = ch +# W = init(out, in) +# conv = GCNConv(ch; bias = bias, init = init) +# Wf = init(out, in) +# Uf = init(out, in) +# Bf = bias ? init(out, in) : nothing +# Wi = init(out, in) +# Ui = init(out, in) +# Bi = bias ? init(out, in) : nothing +# Wo = init(out, in) +# Uo = init(out, in) +# Bo = bias ? init(out, in) : nothing +# Wc = init(out, in) +# Uc = init(out, in) +# Bc = bias ? init(out, in) : nothing +# return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) +# end + +# function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) +# H = egcno.init_state(egcno.out, egcno.in) +# C = egcno.init_state(egcno.out, egcno.in) +# W = egcno.W_init +# X = map(1:tg.num_snapshots) do i +# F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) +# I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) +# O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) +# C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) +# C = F .* C + I .* C̃ +# H = O .* tanh_fast.(C) +# W = H +# egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) +# end +# return X +# end + +# function Base.show(io::IO, egcno::EvolveGCNO) +# print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") +# end From 02919acf7bc78a41253477e22641d33b801d464c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 00:56:28 +0100 Subject: [PATCH 02/11] GConvLSTM --- .../src/GraphNeuralNetworks.jl | 2 +- .../src/layers/temporalconv.jl | 211 +++++++++++++++++- .../src/layers/temporalconv_old.jl | 185 --------------- 3 files changed, 211 insertions(+), 187 deletions(-) diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index 10ff5f310..4f27652f9 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -51,9 +51,9 @@ export HeteroGraphConv include("layers/temporalconv.jl") export GConvGRU, GConvGRUCell, + GConvLSTM, GConvLSTMCell, TGCN, A3TGCN, - GConvLSTM, DCGRU, EvolveGCNO diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 47584bc89..4e3fe0108 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -166,7 +166,7 @@ julia> size(y) # (d_out, timesteps, num_nodes) (3, 5, 5) ``` 
""" -struct GConvGRU{G <: GConvGRUCell} <: GNNLayer +struct GConvGRU{G<:GConvGRUCell} <: GNNLayer cell::G end @@ -186,3 +186,212 @@ function Base.show(io::IO, rnn::GConvGRU) print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))") end + +""" + GConvLSTMCell(in => out, k; [bias, init]) + +Graph Convolutional LSTM recurrent cell from the paper +[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). + +Uses [`ChebConv`](@ref) to model spatial dependencies, +followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. + +# Arguments + +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `k`: Chebyshev polynomial order. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The initial hidden state of the LSTM cell. + If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`. + If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros. + +Performs one recurrence step and returns a tuple `(output, state)`, +where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = GConvLSTMCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state[1]; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" +@concrete struct GConvLSTMCell <: GNNLayer + conv_x_i + conv_h_i + w_i + b_i + conv_x_f + conv_h_f + w_f + b_f + conv_x_c + conv_h_c + w_c + b_c + conv_x_o + conv_h_o + w_o + b_o + k::Int + in::Int + out::Int +end + +Flux.@layer GConvLSTMCell + +function GConvLSTMCell(ch::Pair{Int, Int}, k::Int; + bias::Bool = true, + init = Flux.glorot_uniform) + in, out = ch + # input gate + conv_x_i = ChebConv(in => out, k; bias, init) + conv_h_i = ChebConv(out => out, k; bias, init) + w_i = init(out, 1) + b_i = bias ? Flux.create_bias(w_i, true, out) : false + # forget gate + conv_x_f = ChebConv(in => out, k; bias, init) + conv_h_f = ChebConv(out => out, k; bias, init) + w_f = init(out, 1) + b_f = bias ? Flux.create_bias(w_f, true, out) : false + # cell state + conv_x_c = ChebConv(in => out, k; bias, init) + conv_h_c = ChebConv(out => out, k; bias, init) + w_c = init(out, 1) + b_c = bias ? Flux.create_bias(w_c, true, out) : false + # output gate + conv_x_o = ChebConv(in => out, k; bias, init) + conv_h_o = ChebConv(out => out, k; bias, init) + w_o = init(out, 1) + b_o = bias ? 
Flux.create_bias(w_o, true, out) : false + return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, + conv_x_f, conv_h_f, w_f, b_f, + conv_x_c, conv_h_c, w_c, b_c, + conv_x_o, conv_h_o, w_o, b_o, + k, in, out) +end + +function Flux.initialstates(cell::GConvLSTMCell) + (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out)) +end + +function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c)) + if h isa AbstractVector + h = repeat(h, 1, g.num_nodes) + end + if c isa AbstractVector + c = repeat(c, 1, g.num_nodes) + end + @assert ndims(h) == 2 && ndims(c) == 2 + # input gate + i = cell.conv_x_i(g, x) .+ cell.conv_h_i(g, h) .+ cell.w_i .* c .+ cell.b_i + i = Flux.sigmoid_fast(i) + # forget gate + f = cell.conv_x_f(g, x) .+ cell.conv_h_f(g, h) .+ cell.w_f .* c .+ cell.b_f + f = Flux.sigmoid_fast(f) + # cell state + c = f .* c .+ i .* Flux.tanh_fast(cell.conv_x_c(g, x) .+ cell.conv_h_c(g, h) .+ cell.w_c .* c .+ cell.b_c) + # output gate + o = cell.conv_x_o(g, x) .+ cell.conv_h_o(g, h) .+ cell.w_o .* c .+ cell.b_o + o = Flux.sigmoid_fast(o) + h = o .* Flux.tanh_fast(c) + return h, (h, c) +end + +function Base.show(io::IO, cell::GConvLSTMCell) + print(io, "GConvLSTMCell($(cell.in) => $(cell.out), $(cell.k))") +end + + +""" + GConvLSTM(in => out, k; kws...) + +The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell, +used to process an entire temporal sequence of node features at once. + +The arguments are the same as for [`GConvLSTMCell`](@ref). + +# Forward + + layer(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`. +- `state`: The initial hidden state of the LSTM cell. + If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`. + If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros. + +Applies the recurrent cell to each timestep of the input sequence and returns the output as +an array of size `out x timesteps x num_nodes`. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = GConvLSTM(d_in => d_out, 2); + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer + cell::G +end + +Flux.@layer GConvLSTM + +function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...) 
+ return GConvLSTM(GConvLSTMCell(ch, k; kws...)) +end + +Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell) + +function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray) + return scan(rnn.cell, g, x, initialstates(rnn)) +end + +function Base.show(io::IO, rnn::GConvLSTM) + print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))") +end + diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl index bf1f7ef32..634911f28 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -178,193 +178,8 @@ function Base.show(io::IO, a3tgcn::A3TGCN) print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") end -@concrete struct GConvGRUCell <: GNNLayer - conv_x_r - conv_h_r - conv_x_z - conv_h_z - conv_x_h - conv_h_h - k::Int - in::Int - out::Int -end - -Flux.@layer :noexpand GConvGRUCell - -function GConvGRUCell(ch::Pair{Int, Int}, k::Int; - bias::Bool = true, - init = Flux.glorot_uniform, - ) - in, out = ch - # reset gate - conv_x_r = ChebConv(in => out, k; bias, init) - conv_h_r = ChebConv(out => out, k; bias, init) - # update gate - conv_x_z = ChebConv(in => out, k; bias, init) - conv_h_z = ChebConv(out => out, k; bias, init) - # new gate - conv_x_h = ChebConv(in => out, k; bias, init) - conv_h_h = ChebConv(out => out, k; bias, init) - return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, in, out) -end - -function Flux.initialstates(cell::GConvGRUCell) - zeros_like(cell.conv_x_r.weight, cell.out) -end - -(cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) - -function (cell::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) - h = repeat(h, 1, g.num_nodes) - return cell(g, x, h) -end - -function (ggru::GConvGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) - r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h) - r = Flux.sigmoid_fast(r) - z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h) - z = Flux.sigmoid_fast(z) - h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h) - h̃ = Flux.tanh_fast(h̃) - h = (1 .- z) .* h̃ .+ z .* h - return h, h -end - -function Base.show(io::IO, ggru::GConvGRUCell) - print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))") -end - -""" - GConvGRU(in => out, k, n; [bias, init, init_state]) - -Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper -[Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/abs/1612.07659). - -Uses [`ChebConv`](@ref) to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `k`: Chebyshev polynomial order. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. - -# Forward - ggru(g::GNNGraph, x, [h]) -- `g`: The input graph. -- `x`: The node features. It should be a matrix of size `in x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. - If not provided, it is assumed to be a matrix of zeros. 
- -# Examples - -```jldoctest -julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes); - -julia> y = ggru(g1, x1); - -julia> size(y) -(5, 5) - -julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -julia> z = ggru(g2, x2); - -julia> size(z) -(5, 5, 30) -``` -""" - -struct GConvGRU{G <: GConvGRUCell} <: GNNLayer - cell::G -end - -Flux.@layer GConvGRU - - -# struct GConvLSTMCell <: GNNLayer -# conv_x_i::ChebConv -# conv_h_i::ChebConv -# w_i -# b_i -# conv_x_f::ChebConv -# conv_h_f::ChebConv -# w_f -# b_f -# conv_x_c::ChebConv -# conv_h_c::ChebConv -# w_c -# b_c -# conv_x_o::ChebConv -# conv_h_o::ChebConv -# w_o -# b_o -# k::Int -# state0 -# in::Int -# out::Int -# end - -# Flux.@layer GConvLSTMCell - -# function GConvLSTMCell(ch::Pair{Int, Int}, k::Int, n::Int; -# bias::Bool = true, -# init = Flux.glorot_uniform, -# init_state = Flux.zeros32) -# in, out = ch -# # input gate -# conv_x_i = ChebConv(in => out, k; bias, init) -# conv_h_i = ChebConv(out => out, k; bias, init) -# w_i = init(out, 1) -# b_i = bias ? Flux.create_bias(w_i, true, out) : false -# # forget gate -# conv_x_f = ChebConv(in => out, k; bias, init) -# conv_h_f = ChebConv(out => out, k; bias, init) -# w_f = init(out, 1) -# b_f = bias ? Flux.create_bias(w_f, true, out) : false -# # cell state -# conv_x_c = ChebConv(in => out, k; bias, init) -# conv_h_c = ChebConv(out => out, k; bias, init) -# w_c = init(out, 1) -# b_c = bias ? Flux.create_bias(w_c, true, out) : false -# # output gate -# conv_x_o = ChebConv(in => out, k; bias, init) -# conv_h_o = ChebConv(out => out, k; bias, init) -# w_o = init(out, 1) -# b_o = bias ? Flux.create_bias(w_o, true, out) : false -# state0 = (init_state(out, n), init_state(out, n)) -# return GConvLSTMCell(conv_x_i, conv_h_i, w_i, b_i, -# conv_x_f, conv_h_f, w_f, b_f, -# conv_x_c, conv_h_c, w_c, b_c, -# conv_x_o, conv_h_o, w_o, b_o, -# k, state0, in, out) -# end - -# function (gclstm::GConvLSTMCell)((h, c), g::GNNGraph, x) -# # input gate -# i = gclstm.conv_x_i(g, x) .+ gclstm.conv_h_i(g, h) .+ gclstm.w_i .* c .+ gclstm.b_i -# i = Flux.sigmoid_fast(i) -# # forget gate -# f = gclstm.conv_x_f(g, x) .+ gclstm.conv_h_f(g, h) .+ gclstm.w_f .* c .+ gclstm.b_f -# f = Flux.sigmoid_fast(f) -# # cell state -# c = f .* c .+ i .* Flux.tanh_fast(gclstm.conv_x_c(g, x) .+ gclstm.conv_h_c(g, h) .+ gclstm.w_c .* c .+ gclstm.b_c) -# # output gate -# o = gclstm.conv_x_o(g, x) .+ gclstm.conv_h_o(g, h) .+ gclstm.w_o .* c .+ gclstm.b_o -# o = Flux.sigmoid_fast(o) -# h = o .* Flux.tanh_fast(c) -# return (h,c), h -# end - -# function Base.show(io::IO, gclstm::GConvLSTMCell) -# print(io, "GConvLSTMCell($(gclstm.in) => $(gclstm.out))") -# end # """ # GConvLSTM(in => out, k, n; [bias, init, init_state]) From cfd9ec32111bd620f7a5a1807cba54de290fd3ef Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 08:59:26 +0100 Subject: [PATCH 03/11] GNNRecurrence --- .../src/GraphNeuralNetworks.jl | 5 +- .../src/layers/temporalconv.jl | 281 ++++++++++++++---- .../src/layers/temporalconv_old.jl | 129 -------- 3 files changed, 220 insertions(+), 195 deletions(-) diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index 4f27652f9..fd504f751 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -50,11 +50,12 @@ include("layers/heteroconv.jl") export HeteroGraphConv include("layers/temporalconv.jl") -export GConvGRU, GConvGRUCell, +export 
GNNRecurrence, + GConvGRU, GConvGRUCell, GConvLSTM, GConvLSTMCell, + DCGRU, DCGRUCell, TGCN, A3TGCN, - DCGRU, EvolveGCNO include("layers/pool.jl") diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 4e3fe0108..8c57cfa0f 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -7,6 +7,73 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} return stack(y, dims = 2) end + +""" + GNNRecurrence(cell) + +Construct a recurrent layer that applies the `cell` +to process an entire temporal sequence of node features at once. + +# Forward + + layer(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`. +- `state`: The initial state of the cell. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + +Applies the recurrent cell to each timestep of the input sequence and returns the output as +an array of size `out_features x timesteps x num_nodes`. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> cell = GConvLSTMCell(d_in => d_out, 2) +GConvLSTMCell(2 => 3, 2) # 168 parameters + +julia> layer = GNNRecurrence(cell) +GNNRecurrence( + GConvLSTMCell(2 => 3, 2), # 168 parameters +) # Total: 24 arrays, 168 parameters, 2.023 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +struct GNNRecurrence{G} <: GNNLayer + cell::G +end + +Flux.@layer GNNRecurrence + +Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell) + +function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T} + return rnn(g, x, initialstates(rnn)) +end + +function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T} + return scan(rnn.cell, g, x, state) +end + +function Base.show(io::IO, rnn::GNNRecurrence) + print(io, "GNNRecurrence($(rnn.cell))") +end + + """ GConvGRUCell(in => out, k; [bias, init]) @@ -126,24 +193,13 @@ function Base.show(io::IO, cell::GConvGRUCell) end """ - GConvGRU(in => out, k; kws...) - -The recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell, -used to process an entire temporal sequence of node features at once. + GConvGRU(args...; kws...) -The arguments are the same as for [`GConvGRUCell`](@ref). +Construct a recurrent layer corresponding to the [`GConvGRUCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. -# Forward - - layer(g::GNNGraph, x, [h]) - -- `g`: The input graph. -- `x`: The time-varying node features. It should be an array of size `in x timesteps x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. - If not provided, it is assumed to be a matrix of zeros. - -Applies the recurrent cell to each timestep of the input sequence and returns the output as -an array of size `out x timesteps x num_nodes`. +The arguments are passed to the [`GConvGRUCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. 
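+For example, `GConvGRU(2 => 3, 2)` is equivalent to `GNNRecurrence(GConvGRUCell(2 => 3, 2))`,
+as the constructor defined below simply wraps the cell in a [`GNNRecurrence`](@ref).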
# Examples @@ -158,7 +214,10 @@ julia> g = rand_graph(num_nodes, num_edges); julia> x = rand(Float32, d_in, timesteps, num_nodes); -julia> layer = GConvGRU(d_in => d_out, 2); +julia> layer = GConvGRU(d_in => d_out, 2) +GConvGRU( + GConvGRUCell(2 => 3, 2), # 108 parameters +) # Total: 12 arrays, 108 parameters, 1.148 KiB. julia> y = layer(g, x); @@ -166,25 +225,7 @@ julia> size(y) # (d_out, timesteps, num_nodes) (3, 5, 5) ``` """ -struct GConvGRU{G<:GConvGRUCell} <: GNNLayer - cell::G -end - -Flux.@layer GConvGRU - -function GConvGRU(ch::Pair{Int,Int}, k::Int; kws...) - return GConvGRU(GConvGRUCell(ch, k; kws...)) -end - -Flux.initialstates(rnn::GConvGRU) = Flux.initialstates(rnn.cell) - -function (rnn::GConvGRU)(g::GNNGraph, x::AbstractArray) - return scan(rnn.cell, g, x, initialstates(rnn)) -end - -function Base.show(io::IO, rnn::GConvGRU) - print(io, "GConvGRU($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))") -end +GConvGRU(args...; kws...) = GNNRecurrence(GConvGRUCell(args...; kws...)) """ @@ -268,7 +309,7 @@ julia> size(y) # (d_out, num_nodes) out::Int end -Flux.@layer GConvLSTMCell +Flux.@layer :noexpand GConvLSTMCell function GConvLSTMCell(ch::Pair{Int, Int}, k::Int; bias::Bool = true, @@ -305,6 +346,8 @@ function Flux.initialstates(cell::GConvLSTMCell) (zeros_like(cell.conv_x_i.weight, cell.out), zeros_like(cell.conv_x_i.weight, cell.out)) end +(cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + function (cell::GConvLSTMCell)(g::GNNGraph, x::AbstractMatrix, (h, c)) if h isa AbstractVector h = repeat(h, 1, g.num_nodes) @@ -334,29 +377,74 @@ end """ - GConvLSTM(in => out, k; kws...) + GConvLSTM(args...; kws...) + +Construct a recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. + +The arguments are passed to the [`GConvLSTMCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = GConvLSTM(d_in => d_out, 2) +GNNRecurrence( + GConvLSTMCell(2 => 3, 2), # 168 parameters +) # Total: 24 arrays, 168 parameters, 2.023 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +GConvLSTM(args...; kws...) = GNNRecurrence(GConvLSTMCell(args...; kws...)) + +""" + DCGRUCell(in => out, k; [bias, init]) + +Diffusion Convolutional Recurrent Neural Network (DCGRU) cell from the paper +[Diffusion Convolutional Recurrent Neural Network: Data-driven Traffic Forecasting](https://arxiv.org/abs/1707.01926). + +Applyis a [`DConv`](@ref) layer to model spatial dependencies, +in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependencies. -The recurrent layer corresponding to the [`GConvLSTMCell`](@ref) cell, -used to process an entire temporal sequence of node features at once. +# Arguments -The arguments are the same as for [`GConvLSTMCell`](@ref). +- `in`: Number of input node features. +- `out`: Number of output node features. +- `k`: Diffusion step for the `DConv`. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. # Forward - layer(g::GNNGraph, x, [state]) + cell(g::GNNGraph, x, [h]) - `g`: The input graph. -- `x`: The time-varying node features. 
It should be an array of size `in x timesteps x num_nodes`. -- `state`: The initial hidden state of the LSTM cell. - If given, it is a tuple `(h, c)` where both elements are matrices of size `out x num_nodes`. - If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. + If not provided, it is assumed to be a matrix of zeros. -Applies the recurrent cell to each timestep of the input sequence and returns the output as -an array of size `out x timesteps x num_nodes`. +Performs one recurrence step and returns a tuple `(h, h)`, +where `h` is the updated hidden state of the GRU cell. # Examples ```jldoctest +julia> using GraphNeuralNetworks, Flux + julia> num_nodes, num_edges = 5, 10; julia> d_in, d_out = 2, 3; @@ -365,33 +453,98 @@ julia> timesteps = 5; julia> g = rand_graph(num_nodes, num_edges); -julia> x = rand(Float32, d_in, timesteps, num_nodes); +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; -julia> layer = GConvLSTM(d_in => d_out, 2); +julia> cell = DCGRUCell(d_in => d_out, 2); -julia> y = layer(g, x); +julia> state = Flux.initialstates(cell); -julia> size(y) # (d_out, timesteps, num_nodes) -(3, 5, 5) +julia> y = state; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) ``` -""" -struct GConvLSTM{G<:GConvLSTMCell} <: GNNLayer - cell::G +""" +struct DCGRUCell + in::Int + out::Int + k::Int + dconv_u::DConv + dconv_r::DConv + dconv_c::DConv end -Flux.@layer GConvLSTM +Flux.@layer :noexpand DCGRUCell -function GConvLSTM(ch::Pair{Int,Int}, k::Int; kws...) - return GConvLSTM(GConvLSTMCell(ch, k; kws...)) +function DCGRUCell(ch::Pair{Int,Int}, k::Int; bias = true, init = glorot_uniform) + in, out = ch + dconv_u = DConv((in + out) => out, k; bias, init) + dconv_r = DConv((in + out) => out, k; bias, init) + dconv_c = DConv((in + out) => out, k; bias, init) + return DCGRUCell(in, out, k, dconv_u, dconv_r, dconv_c) end -Flux.initialstates(rnn::GConvLSTM) = Flux.initialstates(rnn.cell) +Flux.initialstates(cell::DCGRUCell) = zeros_like(cell.dconv_u.weights, cell.out) + +(cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) -function (rnn::GConvLSTM)(g::GNNGraph, x::AbstractArray) - return scan(rnn.cell, g, x, initialstates(rnn)) +function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + return cell(g, x, repeat(h, 1, g.num_nodes)) end -function Base.show(io::IO, rnn::GConvLSTM) - print(io, "GConvLSTM($(rnn.cell.in) => $(rnn.cell.out), $(rnn.cell.k))") +function (cell::DCGRUCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + h̃ = vcat(x, h) + z = cell.dconv_u(g, h̃) + z = NNlib.sigmoid_fast.(z) + r = cell.dconv_r(g, h̃) + r = NNlib.sigmoid_fast.(r) + ĥ = vcat(x, h .* r) + c = cell.dconv_c(g, ĥ) + c = NNlib.tanh_fast.(c) + h = z.* h + (1 .- z) .* c + return h, h +end + +function Base.show(io::IO, cell::DCGRUCell) + print(io, "DCGRUCell($(cell.in) => $(cell.out), $(cell.k))") end +""" + DCGRU(args...; kws...) + +Construct a recurrent layer corresponding to the [`DCGRUCell`](@ref) cell. +It can be used to process an entire temporal sequence of node features at once. + +The arguments are passed to the [`DCGRUCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. 
+ +# Examples +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = DCGRU(d_in => d_out, 2) +GNNRecurrence( + DCGRUCell(2 => 3, 2), # 189 parameters +) # Total: 6 arrays, 189 parameters, 1.184 KiB. + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...)) + + diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl index 634911f28..0c78eb09a 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -179,135 +179,6 @@ function Base.show(io::IO, a3tgcn::A3TGCN) end - - -# """ -# GConvLSTM(in => out, k, n; [bias, init, init_state]) - -# Graph Convolutional Long Short-Term Memory (GConvLSTM) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659). - -# Performs a layer of ChebConv to model spatial dependencies, followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Chebyshev polynomial order. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> gclstm = GConvLSTM(2 => 5, 2, g1.num_nodes); - -# julia> y = gclstm(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = gclstm(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # GConvLSTM(ch, k, n; kwargs...) 
= Flux.Recur(GConvLSTMCell(ch, k, n; kwargs...)) -# # Flux.Recur(tgcn::GConvLSTMCell) = Flux.Recur(tgcn, tgcn.state0) - -# # (l::Flux.Recur{GConvLSTMCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{GConvLSTMCell}, g::GNNGraph) = l(g) - -# struct DCGRUCell -# in::Int -# out::Int -# state0 -# k::Int -# dconv_u::DConv -# dconv_r::DConv -# dconv_c::DConv -# end - -# Flux.@layer DCGRUCell - -# function DCGRUCell(ch::Pair{Int,Int}, k::Int, n::Int; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# dconv_u = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_r = DConv((in + out) => out, k; bias=bias, init=init) -# dconv_c = DConv((in + out) => out, k; bias=bias, init=init) -# state0 = init_state(out, n) -# return DCGRUCell(in, out, state0, k, dconv_u, dconv_r, dconv_c) -# end - -# function (dcgru::DCGRUCell)(h, g::GNNGraph, x) -# h̃ = vcat(x, h) -# z = dcgru.dconv_u(g, h̃) -# z = NNlib.sigmoid_fast.(z) -# r = dcgru.dconv_r(g, h̃) -# r = NNlib.sigmoid_fast.(r) -# ĥ = vcat(x, h .* r) -# c = dcgru.dconv_c(g, ĥ) -# c = tanh.(c) -# h = z.* h + (1 .- z) .* c -# return h, h -# end - -# function Base.show(io::IO, dcgru::DCGRUCell) -# print(io, "DCGRUCell($(dcgru.in) => $(dcgru.out), $(dcgru.k))") -# end - -# """ -# DCGRU(in => out, k, n; [bias, init, init_state]) - -# Diffusion Convolutional Recurrent Neural Network (DCGRU) layer from the paper [Diffusion Convolutional Recurrent Neural -# Network: Data-driven Traffic Forecasting](https://arxiv.org/pdf/1707.01926). - -# Performs a Diffusion Convolutional layer to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `k`: Diffusion step. -# - `n`: Number of nodes in the graph. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# # Examples - -# ```jldoctest -# julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5); - -# julia> dcgru = DCGRU(2 => 5, 2, g1.num_nodes); - -# julia> y = dcgru(g1, x1); - -# julia> size(y) -# (5, 5) - -# julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30); - -# julia> z = dcgru(g2, x2); - -# julia> size(z) -# (5, 5, 30) -# ``` -# """ -# # DCGRU(ch, k, n; kwargs...) 
= Flux.Recur(DCGRUCell(ch, k, n; kwargs...)) -# # Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0) - -# # (l::Flux.Recur{DCGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g))) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x) -# # _applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g) - # """ # EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) From 383d46444b1176a11f5650b52297961f2f27cb78 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 12:19:09 +0100 Subject: [PATCH 04/11] EvolveGCNOCell --- .../src/GraphNeuralNetworks.jl | 4 +- .../src/layers/temporalconv.jl | 121 ++++++++++- .../src/layers/temporalconv_old.jl | 194 +++++++++--------- 3 files changed, 206 insertions(+), 113 deletions(-) diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index fd504f751..e2f4cb8f7 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -54,9 +54,9 @@ export GNNRecurrence, GConvGRU, GConvGRUCell, GConvLSTM, GConvLSTMCell, DCGRU, DCGRUCell, + EvolveGCNO, EvolveGCNOCell, TGCN, - A3TGCN, - EvolveGCNO + A3TGCN include("layers/pool.jl") export GlobalPool, diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 8c57cfa0f..b109f4e6b 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -1,12 +1,23 @@ function scan(cell, g::GNNGraph, x::AbstractArray{T,3}, state) where {T} y = [] - for x_t in eachslice(x, dims = 2) - yt, state = cell(g, x_t, state) + for xt in eachslice(x, dims = 2) + yt, state = cell(g, xt, state) y = vcat(y, [yt]) end return stack(y, dims = 2) end +function scan(cell, tg::TemporalSnapshotsGNNGraph, x::AbstractVector, state) + # @assert length(x) == length(tg.snapshots) + y = [] + for (t, xt) in enumerate(x) + gt = tg.snapshots[t] + yt, state = cell(gt, xt, state) + y = vcat(y, [yt]) + end + return y +end + """ GNNRecurrence(cell) @@ -20,7 +31,7 @@ to process an entire temporal sequence of node features at once. - `g`: The input graph. - `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`. -- `state`: The initial state of the cell. +- `state`: The current state of the cell. If not provided, it is generated by calling `Flux.initialstates(cell)`. Applies the recurrent cell to each timestep of the input sequence and returns the output as @@ -61,11 +72,11 @@ Flux.@layer GNNRecurrence Flux.initialstates(rnn::GNNRecurrence) = Flux.initialstates(rnn.cell) -function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}) where {T} +function (rnn::GNNRecurrence)(g, x) return rnn(g, x, initialstates(rnn)) end -function (rnn::GNNRecurrence)(g::GNNGraph, x::AbstractArray{T,3}, state) where {T} +function (rnn::GNNRecurrence)(g, x, state) where {T} return scan(rnn.cell, g, x, state) end @@ -97,7 +108,7 @@ followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - `g`: The input graph. - `x`: The node features. It should be a matrix of size `in x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. +- `h`: The current hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. If not provided, it is assumed to be a matrix of zeros. 
Performs one recurrence step and returns a tuple `(h, h)`, @@ -251,9 +262,9 @@ followed by a Long Short-Term Memory (LSTM) cell to model temporal dependencies. - `g`: The input graph. - `x`: The node features. It should be a matrix of size `in x num_nodes`. -- `state`: The initial hidden state of the LSTM cell. +- `state`: The current state of the LSTM cell. If given, it is a tuple `(h, c)` where both `h` and `c` are arrays of size `out x num_nodes`. - If not provided, the initial hidden state is assumed to be a tuple of matrices of zeros. + If not provided, it is assumed to be a tuple of matrices of zeros. Performs one recurrence step and returns a tuple `(output, state)`, where `output` is the updated hidden state `h` of the LSTM cell and `state` is the updated tuple `(h, c)`. @@ -434,7 +445,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen - `g`: The input graph. - `x`: The node features. It should be a matrix of size `in x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a matrix of size `out x num_nodes`. +- `h`: The current state of the GRU cell. It is a matrix of size `out x num_nodes`. If not provided, it is assumed to be a matrix of zeros. Performs one recurrence step and returns a tuple `(h, h)`, @@ -547,4 +558,96 @@ julia> size(y) # (d_out, timesteps, num_nodes) """ DCGRU(args...; kws...) = GNNRecurrence(DCGRUCell(args...; kws...)) +"""" + EvolveGCNOCell(in => out; bias = true, init = glorot_uniform) + +Evolving Graph Convolutional Network cell of type "-O" from the paper +[EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/abs/1902.10191). + +Uses a [`GCNConv`](@ref) layer to model spatial dependencies, and an `LSTMCell` to model temporal dependencies. +Can work with time-varying graphs and node features. + +# Arguments + +- `in => out`: A pair where `in` is the number of input node features and `out` + is the number of output node features. +- `bias`: Add learnable bias for the convolution and the lstm cell. Default `true`. +- `init`: Weights' initializer for the convolution. Default `glorot_uniform`. + +# Forward + + cell(g::GNNGraph, x, [state]) -> x, state + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The current state of the cell. + A state is a tuple `(weight, lstm)` where `weight` is the convolution's weight and `lstm` is the lstm's state. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + +Returns the updated node features `x` and the updated state. 
+ +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = [rand_graph(num_nodes, num_edges) for t in 1:timesteps]; + +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell1 = EvolveGCNOCell(d_in => d_out) +EvolveGCNOCell(2 => 3) # 321 parameters + +julia> cell2 = EvolveGCNOCell(d_out => d_out) +EvolveGCNOCell(3 => 3) # 696 parameters + +julia> state1 = Flux.initialstates(cell1); + +julia> state2 = Flux.initialstates(cell2); + +julia> outputs = []; + +julia> for t in 1:timesteps + zt, state1 = cell1(g[t], x[t], state1) + yt, state2 = cell2(g[t], zt, state2) + outputs = vcat(outputs, [yt]) + end + +julia> size(outputs[end]) # (d_out, num_nodes) +(3, 5) +``` +""" +struct EvolveGCNOCell{C,L} <: GNNLayer + in::Int + out::Int + conv::C + lstm::L +end + +Flux.@layer :noexpand EvolveGCNOCell +function EvolveGCNOCell((in,out)::Pair{Int,Int}; bias = true, init = glorot_uniform) + conv = GCNConv(in => out; bias, init) + lstm = LSTMCell(in*out => in*out; bias) + return EvolveGCNOCell(in, out, conv, lstm) +end + +function Flux.initialstates(cell::EvolveGCNOCell) + weight = reshape(cell.conv.weight, :) + lstm = Flux.initialstates(cell.lstm) + return (; weight, lstm) +end + +function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state) + weight, state_lstm = cell.lstm(state.weight, state.lstm) + x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) + return x, (; weight, lstm = state_lstm) +end + +function Base.show(io::IO, egcno::EvolveGCNOCell) + print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))") +end diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl index 0c78eb09a..866a75098 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -1,13 +1,3 @@ -function scan(cell, g, x, state) - y = [] - for x_t in eachslice(x, dims = 2) - yt, state = cell(g, x_t, state) - y = vcat(y, [yt]) - end - return stack(y, dims = 2) -end - - struct TGCNCell{C,G} <: GNNLayer conv::C gru::G @@ -179,97 +169,97 @@ function Base.show(io::IO, a3tgcn::A3TGCN) end -# """ -# EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - -# Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). - -# Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. - - -# # Arguments - -# - `in`: Number of input features. -# - `out`: Number of output features. -# - `bias`: Add learnable bias. Default `true`. -# - `init`: Weights' initializer. Default `glorot_uniform`. -# - `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. 
- -# # Examples - -# ```jldoctest -# julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) -# TemporalSnapshotsGNNGraph: -# num_nodes: [10, 10, 10] -# num_edges: [20, 14, 22] -# num_snapshots: 3 - -# julia> ev = EvolveGCNO(4 => 5) -# EvolveGCNO(4 => 5) - -# julia> size(ev(tg, tg.ndata.x)) -# (3,) - -# julia> size(ev(tg, tg.ndata.x)[1]) -# (5, 10) -# ``` -# """ -# struct EvolveGCNO -# conv -# W_init -# init_state -# in::Int -# out::Int -# Wf -# Uf -# Bf -# Wi -# Ui -# Bi -# Wo -# Uo -# Bo -# Wc -# Uc -# Bc -# end +""" + EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) + +Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). + +Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. + + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `bias`: Add learnable bias. Default `true`. +- `init`: Weights' initializer. Default `glorot_uniform`. +- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. + +# Examples + +```jldoctest +julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) +TemporalSnapshotsGNNGraph: + num_nodes: [10, 10, 10] + num_edges: [20, 14, 22] + num_snapshots: 3 + +julia> ev = EvolveGCNO(4 => 5) +EvolveGCNO(4 => 5) + +julia> size(ev(tg, tg.ndata.x)) +(3,) + +julia> size(ev(tg, tg.ndata.x)[1]) +(5, 10) +``` +""" +struct EvolveGCNO + conv + W_init + init_state + in::Int + out::Int + Wf + Uf + Bf + Wi + Ui + Bi + Wo + Uo + Bo + Wc + Uc + Bc +end -# function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) -# in, out = ch -# W = init(out, in) -# conv = GCNConv(ch; bias = bias, init = init) -# Wf = init(out, in) -# Uf = init(out, in) -# Bf = bias ? init(out, in) : nothing -# Wi = init(out, in) -# Ui = init(out, in) -# Bi = bias ? init(out, in) : nothing -# Wo = init(out, in) -# Uo = init(out, in) -# Bo = bias ? init(out, in) : nothing -# Wc = init(out, in) -# Uc = init(out, in) -# Bc = bias ? init(out, in) : nothing -# return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) -# end - -# function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) -# H = egcno.init_state(egcno.out, egcno.in) -# C = egcno.init_state(egcno.out, egcno.in) -# W = egcno.W_init -# X = map(1:tg.num_snapshots) do i -# F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) -# I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) -# O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) -# C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) -# C = F .* C + I .* C̃ -# H = O .* tanh_fast.(C) -# W = H -# egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) -# end -# return X -# end +function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) + in, out = ch + W = init(out, in) + conv = GCNConv(ch; bias = bias, init = init) + Wf = init(out, in) + Uf = init(out, in) + Bf = bias ? init(out, in) : nothing + Wi = init(out, in) + Ui = init(out, in) + Bi = bias ? init(out, in) : nothing + Wo = init(out, in) + Uo = init(out, in) + Bo = bias ? 
init(out, in) : nothing + Wc = init(out, in) + Uc = init(out, in) + Bc = bias ? init(out, in) : nothing + return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) +end + +function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) + H = egcno.init_state(egcno.out, egcno.in) + C = egcno.init_state(egcno.out, egcno.in) + W = egcno.W_init + X = map(1:tg.num_snapshots) do i + F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) + I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) + O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) + C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) + C = F .* C + I .* C̃ + H = O .* tanh_fast.(C) + W = H + egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) + end + return X +end -# function Base.show(io::IO, egcno::EvolveGCNO) -# print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") -# end +function Base.show(io::IO, egcno::EvolveGCNO) + print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") +end From 13142f27b8b3006665014758f6b41fe64964a3ab Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 12:19:52 +0100 Subject: [PATCH 05/11] cleanup --- .../src/layers/temporalconv_old.jl | 96 ------------------- 1 file changed, 96 deletions(-) diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl index 866a75098..c1ea8ed14 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -167,99 +167,3 @@ end function Base.show(io::IO, a3tgcn::A3TGCN) print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") end - - -""" - EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - -Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191). - -Perfoms a Graph Convolutional layer with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph. - - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. -- `init_state`: Initial state of the hidden stat of the LSTM layer. Default `zeros32`. - -# Examples - -```jldoctest -julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))]) -TemporalSnapshotsGNNGraph: - num_nodes: [10, 10, 10] - num_edges: [20, 14, 22] - num_snapshots: 3 - -julia> ev = EvolveGCNO(4 => 5) -EvolveGCNO(4 => 5) - -julia> size(ev(tg, tg.ndata.x)) -(3,) - -julia> size(ev(tg, tg.ndata.x)[1]) -(5, 10) -``` -""" -struct EvolveGCNO - conv - W_init - init_state - in::Int - out::Int - Wf - Uf - Bf - Wi - Ui - Bi - Wo - Uo - Bo - Wc - Uc - Bc -end - -function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32) - in, out = ch - W = init(out, in) - conv = GCNConv(ch; bias = bias, init = init) - Wf = init(out, in) - Uf = init(out, in) - Bf = bias ? init(out, in) : nothing - Wi = init(out, in) - Ui = init(out, in) - Bi = bias ? init(out, in) : nothing - Wo = init(out, in) - Uo = init(out, in) - Bo = bias ? init(out, in) : nothing - Wc = init(out, in) - Uc = init(out, in) - Bc = bias ? 
init(out, in) : nothing - return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc) -end - -function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x) - H = egcno.init_state(egcno.out, egcno.in) - C = egcno.init_state(egcno.out, egcno.in) - W = egcno.W_init - X = map(1:tg.num_snapshots) do i - F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf) - I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi) - O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo) - C̃ = Flux.tanh_fast.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc) - C = F .* C + I .* C̃ - H = O .* tanh_fast.(C) - W = H - egcno.conv(tg.snapshots[i], x[i]; conv_weight = H) - end - return X -end - -function Base.show(io::IO, egcno::EvolveGCNO) - print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))") -end From b229ab27168b1d3dcdf0e084bdaab3149282a3fb Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 18:32:37 +0100 Subject: [PATCH 06/11] EvolveGCNO --- .../src/layers/temporalconv.jl | 115 ++++++++++++++++-- 1 file changed, 106 insertions(+), 9 deletions(-) diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index b109f4e6b..51575caf3 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -22,23 +22,42 @@ end """ GNNRecurrence(cell) -Construct a recurrent layer that applies the `cell` -to process an entire temporal sequence of node features at once. +Construct a recurrent layer that applies the graph recurrent `cell` forward +multiple times to process an entire temporal sequence of node features at once. + +The `cell` has to satisfy the following interface for the forward pass: +`yt, state = cell(g, xt, state)`, where `xt` is the input node features, +`yt` is the updated node features, `state` is the cell state to be updated. # Forward - layer(g::GNNGraph, x, [state]) + layer(g, x, [state]) -- `g`: The input graph. -- `x`: The time-varying node features. An array of size `in x timesteps x num_nodes`. -- `state`: The current state of the cell. +Applies the recurrent cell to each timestep of the input sequence. + +## Arguments + +- `g`: The input `GNNGraph` or `TemporalSnapshotsGNNGraph`. + - If `GNNGraph`, the same graph is used for all timesteps. + - If `TemporalSnapshotsGNNGraph`, a different graph is used for each timestep. Not all cells support this. +- `x`: The time-varying node features. + - If `g` is `GNNGraph`, it is an array of size `in x timesteps x num_nodes`. + - If `g` is `TemporalSnapshotsGNNGraph`, it is an vector of length `timesteps`, + with element `t` of size `in x num_nodes_t`. +- `state`: The initial state for the cell. If not provided, it is generated by calling `Flux.initialstates(cell)`. -Applies the recurrent cell to each timestep of the input sequence and returns the output as -an array of size `out_features x timesteps x num_nodes`. +## Return + +Returns the updated node features: +- If `g` is `GNNGraph`, returns an array of size `out_features x timesteps x num_nodes`. +- If `g` is `TemporalSnapshotsGNNGraph`, returns a vector of length `timesteps`, + with element `t` of size `out_features x num_nodes_t`. # Examples +The following example considers a static graph and a time-varying node features. 
+ ```jldoctest julia> num_nodes, num_edges = 5, 10; @@ -47,6 +66,9 @@ julia> d_in, d_out = 2, 3; julia> timesteps = 5; julia> g = rand_graph(num_nodes, num_edges); +GNNGraph: + num_nodes: 5 + num_edges: 10 julia> x = rand(Float32, d_in, timesteps, num_nodes); @@ -63,6 +85,38 @@ julia> y = layer(g, x); julia> size(y) # (d_out, timesteps, num_nodes) (3, 5, 5) ``` +Now consider a time-varying graph and time-varying node features. +```jldoctest +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> num_nodes = [10, 10, 10, 10, 10]; + +julia> num_edges = [10, 12, 14, 16, 18]; + +julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) + +julia> x = [rand(Float32, d_in, n) for n in num_nodes]; + +julia> cell = EvolveGCNOCell(d_in => d_out) +EvolveGCNOCell(2 => 3) # 321 parameters + +julia> layer = GNNRecurrence(cell) +GNNRecurrence( + EvolveGCNOCell(2 => 3), # 321 parameters +) # Total: 5 arrays, 321 parameters, 1.535 KiB. + +julia> y = layer(tg, x); + +julia> length(y) # timesteps +5 + +julia> size(y[end]) # (d_out, num_nodes[end]) +(3, 10) +``` """ struct GNNRecurrence{G} <: GNNLayer cell::G @@ -437,7 +491,7 @@ in combination with a Gated Recurrent Unit (GRU) cell to model temporal dependen - `out`: Number of output node features. - `k`: Diffusion step for the `DConv`. - `bias`: Add learnable bias. Default `true`. -- `init`: Weights' initializer. Default `glorot_uniform`. +- `init`: Convolution weights' initializer. Default `glorot_uniform`. # Forward @@ -651,3 +705,46 @@ end function Base.show(io::IO, egcno::EvolveGCNOCell) print(io, "EvolveGCNOCell($(egcno.in) => $(egcno.out))") end + + +""" + EvolveGCNO(args...; kws...) + +Construct a recurrent layer corresponding to the [`EvolveGCNOCell`](@ref) cell. +It can be used to process an entire temporal sequence of graphs and node features at once. + +The arguments are passed to the [`EvolveGCNOCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> num_nodes = [10, 10, 10, 10, 10]; + +julia> num_edges = [10, 12, 14, 16, 18]; + +julia> snapshots = [rand_graph(n, m) for (n, m) in zip(num_nodes, num_edges)]; + +julia> tg = TemporalSnapshotsGNNGraph(snapshots) + +julia> x = [rand(Float32, d_in, n) for n in num_nodes]; + +julia> cell = EvolveGCNO(d_in => d_out) +GNNRecurrence( + EvolveGCNOCell(2 => 3), # 321 parameters +) # Total: 5 arrays, 321 parameters, 1.535 KiB. + +julia> y = layer(tg, x); + +julia> length(y) # timesteps +5 + +julia> size(y[end]) # (d_out, num_nodes[end]) +(3, 10) +``` +""" +EvolveGCNO(args...; kws...) = GNNRecurrence(EvolveGCNOCell(args...; kws...)) From 23d9e45b575682449b583d1399508ecfe2425d0c Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 16 Dec 2024 19:20:05 +0100 Subject: [PATCH 07/11] TGCNCell --- .../src/layers/temporalconv.jl | 48 +++++++++++++++++++ .../src/layers/temporalconv_old.jl | 3 +- 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 51575caf3..32ac11c71 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -748,3 +748,51 @@ julia> size(y[end]) # (d_out, num_nodes[end]) ``` """ EvolveGCNO(args...; kws...) 
= GNNRecurrence(EvolveGCNOCell(args...; kws...)) + + + +@concrete struct TGCNCell <: GNNLayer + in::Int + out::Int + conv_z + dense_z + conv_r + dense_r + conv_h + dense_h +end + +Flux.@layer :noexpand TGCNCell + +function TGCNCell((in, out)::Pair{Int, Int}; kws...) + conv_z = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_z = Dense(2*out => out, sigmoid) + conv_r = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_r = Dense(2*out => out, sigmoid) + conv_h = GNNChain(GCNConv(in => out, relu; kws...), GCNConv(out => out; kws...)) + dense_h = Dense(2*out => out, tanh) + return TGCNCell(in, out, conv_z, dense_z, conv_r, dense_r, conv_h, dense_h) +end + +Flux.initialstates(cell::TGCNCell) = zeros_like(cell.dense_z.weight, cell.out) + +(cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + +function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractVector) + return cell(g, x, repeat(h, 1, g.num_nodes)) +end + +function (cell::TGCNCell)(g::GNNGraph, x::AbstractMatrix, h::AbstractMatrix) + z = cell.conv_z(g, x) + z = cell.dense_z(vcat(z, h)) + r = cell.conv_r(g, x) + r = cell.dense_r(vcat(r, h)) + h̃ = cell.conv_h(g, x) + h̃ = cell.dense_h(vcat(h̃, r .* h)) + h = (1 .- z) .* h .+ z .* h̃ + return h, h +end + +function Base.show(io::IO, cell::TGCNCell) + print(io, "TGCNCell($(cell.in) => $(cell.out))") +end \ No newline at end of file diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl index c1ea8ed14..9be5abfa1 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv_old.jl @@ -7,11 +7,10 @@ end Flux.@layer :noexpand TGCNCell -function TGCNCell(ch::Pair{Int, Int}; +function TGCNCell((in, out)::Pair{Int, Int}; bias::Bool = true, init = Flux.glorot_uniform, add_self_loops = false) - in, out = ch conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops) gru = GRUCell(out => out) return TGCNCell(conv, gru, in, out) From d762110c81cad203a21ee375a266fa6cff6acc64 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 04:29:53 +0100 Subject: [PATCH 08/11] TGCCN --- .../src/GraphNeuralNetworks.jl | 3 +- .../src/layers/temporalconv.jl | 89 +++++++++- .../src/layers/temporalconv_old.jl | 168 ------------------ 3 files changed, 88 insertions(+), 172 deletions(-) delete mode 100644 GraphNeuralNetworks/src/layers/temporalconv_old.jl diff --git a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl index e2f4cb8f7..745e46aaa 100644 --- a/GraphNeuralNetworks/src/GraphNeuralNetworks.jl +++ b/GraphNeuralNetworks/src/GraphNeuralNetworks.jl @@ -55,8 +55,7 @@ export GNNRecurrence, GConvLSTM, GConvLSTMCell, DCGRU, DCGRUCell, EvolveGCNO, EvolveGCNOCell, - TGCN, - A3TGCN + TGCN, TGCNCell include("layers/pool.jl") export GlobalPool, diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index 32ac11c71..d0cd4b757 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -130,7 +130,7 @@ function (rnn::GNNRecurrence)(g, x) return rnn(g, x, initialstates(rnn)) end -function (rnn::GNNRecurrence)(g, x, state) where {T} +function (rnn::GNNRecurrence)(g, x, state) return scan(rnn.cell, g, x, state) end @@ -750,7 +750,60 @@ julia> size(y[end]) # (d_out, num_nodes[end]) EvolveGCNO(args...; kws...) 
= GNNRecurrence(EvolveGCNOCell(args...; kws...)) +""" + TGCNCell(in => out; kws...) + +Recurrent graph convolutional cell from the paper +[T-GCN: A Temporal Graph Convolutional +Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320). + +Uses two stacked [`GCNConv`](@ref) layers to model spatial dependencies, +and a GRU mechanism to model temporal dependencies. + +`in` and `out` are the number of input and output node features, respectively. +The keyword arguments are passed to the [`GCNConv`](@ref) constructor. + +# Forward + + cell(g::GNNGraph, x, [state]) + +- `g`: The input graph. +- `x`: The node features. It should be a matrix of size `in x num_nodes`. +- `state`: The current state of the cell. + If not provided, it is generated by calling `Flux.initialstates(cell)`. + The state is a matrix of size `out x num_nodes`. + +Returns the updated node features and the updated state. + +# Examples + +```jldoctest +julia> using GraphNeuralNetworks, Flux + +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); +julia> x = [rand(Float32, d_in, num_nodes) for t in 1:timesteps]; + +julia> cell = DCGRUCell(d_in => d_out, 2); + +julia> state = Flux.initialstates(cell); + +julia> y = state; + +julia> for xt in x + y, state = cell(g, xt, state) + end + +julia> size(y) # (d_out, num_nodes) +(3, 5) +``` +""" @concrete struct TGCNCell <: GNNLayer in::Int out::Int @@ -795,4 +848,36 @@ end function Base.show(io::IO, cell::TGCNCell) print(io, "TGCNCell($(cell.in) => $(cell.out))") -end \ No newline at end of file +end + +""" + TGCN(args...; kws...) + +Construct a recurrent layer corresponding to the [`TGCNCell`](@ref) cell. + +The arguments are passed to the [`TGCNCell`](@ref) constructor. +See [`GNNRecurrence`](@ref) for more details. + +# Examples + +```jldoctest +julia> num_nodes, num_edges = 5, 10; + +julia> d_in, d_out = 2, 3; + +julia> timesteps = 5; + +julia> g = rand_graph(num_nodes, num_edges); + +julia> x = rand(Float32, d_in, timesteps, num_nodes); + +julia> layer = TGCN(d_in => d_out) + +julia> y = layer(g, x); + +julia> size(y) # (d_out, timesteps, num_nodes) +(3, 5, 5) +``` +""" +TGCN(args...; kws...) = GNNRecurrence(TGCNCell(args...; kws...)) + diff --git a/GraphNeuralNetworks/src/layers/temporalconv_old.jl b/GraphNeuralNetworks/src/layers/temporalconv_old.jl deleted file mode 100644 index 9be5abfa1..000000000 --- a/GraphNeuralNetworks/src/layers/temporalconv_old.jl +++ /dev/null @@ -1,168 +0,0 @@ -struct TGCNCell{C,G} <: GNNLayer - conv::C - gru::G - in::Int - out::Int -end - -Flux.@layer :noexpand TGCNCell - -function TGCNCell((in, out)::Pair{Int, Int}; - bias::Bool = true, - init = Flux.glorot_uniform, - add_self_loops = false) - conv = GCNConv(in => out, sigmoid; init, bias, add_self_loops) - gru = GRUCell(out => out) - return TGCNCell(conv, gru, in, out) -end - -Flux.initialstates(cell::TGCNCell) = initialstates(cell.gru) -(cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat) = cell(g, x, initialstates(cell)) - -function (cell::TGCNCell)(g::GNNGraph, x::AbstractVecOrMat, h::AbstractVecOrMat) - x = cell.conv(g, x) - h = cell.gru(x, h) - return h -end - -function Base.show(io::IO, cell::TGCNCell) - print(io, "TGCNCell($(cell.in) => $(cell.out))") -end - -""" - TGCN(in => out; [bias, init, add_self_loops]) - -Temporal Graph Convolutional Network (T-GCN) recurrent layer from the paper [T-GCN: A Temporal Graph Convolutional Network for Traffic Prediction](https://arxiv.org/pdf/1811.05320.pdf). 
- -Performs a layer of GCNConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `bias`: Add learnable bias. Default `true`. -- `init`: Convolution's weights initializer. Default `glorot_uniform`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. - -# Forward - - tgcn(g::GNNGraph, x, [h]) - -- `g`: The input graph. -- `x`: The input to the TGCN. It should be a matrix size `in x timesteps` or an array of size `in x timesteps x num_nodes`. -- `h`: The initial hidden state of the GRU cell. If given, it is a vector of size `out` or a matrix of size `out x num_nodes`. - If not provided, it is assumed to be a vector of zeros. - -# Examples - -```jldoctest -julia> in, out = 2, 3; - -julia> tgcn = TGCN(in => out) -TGCN( - TGCNCell( - GCNConv(2 => 3, σ), # 9 parameters - GRUCell(3 => 3), # 63 parameters - ), -) # Total: 5 arrays, 72 parameters, 560 bytes. - -julia> num_nodes = 5; num_edges = 10; timesteps = 4; - -julia> g = rand_graph(num_nodes, num_edges); - -julia> x = rand(Float32, in, timesteps, num_nodes); - -julia> tgcn(g, x) |> size -(3, 4, 5) -``` -""" -struct TGCN{C<:TGCNCell} <: GNNLayer - cell::C -end - -Flux.@layer TGCN - -TGCN(ch::Pair{Int, Int}; kws...) = TGCN(TGCNCell(ch; kws...)) - -initialstates(tgcn::TGCN) = initialstates(tgcn.cell) - -(tgcn::TGCN)(g::GNNGraph, x) = tgcn(g, x, initialstates(tgcn)) - -function (tgcn::TGCN)(g::GNNGraph, x::AbstractArray, h) - return scan(tgcn.cell, g, x, h) -end - -Base.show(io::IO, tgcn::TGCN) = print(io, "TGCN($(tgcn.cell.in) => $(tgcn.cell.out))") - -####### TO BE PORTED TO FLUX v0.15 from here ############################ - -""" - A3TGCN(in => out; [bias, init, add_self_loops]) - -Attention Temporal Graph Convolutional Network (A3T-GCN) model from the paper [A3T-GCN: Attention Temporal Graph -Convolutional Network for Traffic Forecasting](https://arxiv.org/pdf/2006.11583.pdf). - -Performs a TGCN layer, followed by a soft attention layer. - -# Arguments - -- `in`: Number of input features. -- `out`: Number of output features. -- `bias`: Add learnable bias. Default `true`. -- `init`: Convolution's weights initializer. Default `glorot_uniform`. -- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. - -# Examples - -```jldoctest -julia> in, out = 2, 3; - -julia> model = A3TGCN(in => out) -TGCN( - TGCNCell( - GCNConv(2 => 3, σ), # 9 parameters - GRUCell(3 => 3), # 63 parameters - ), -) # Total: 5 arrays, 72 parameters, 560 bytes. - -julia> num_nodes = 5; num_edges = 10; timesteps = 4; - -julia> g = rand_graph(num_nodes, num_edges); - -julia> x = rand(Float32, in, timesteps, num_nodes); - -julia> model(g, x) |> size -(3, 4, 5) -``` -""" -struct A3TGCN <: GNNLayer - tgcn::TGCN - dense1::Dense - dense2::Dense - in::Int - out::Int -end - -Flux.@layer A3TGCN - -function A3TGCN(ch::Pair{Int, Int}; kws...) - in, out = ch - tgcn = TGCN(in => out; kws...) 
- dense1 = Dense(out => out) - dense2 = Dense(out => out) - return A3TGCN(tgcn, dense1, dense2, in, out) -end - -function (a3tgcn::A3TGCN)(g::GNNGraph, x::AbstractArray, h) - h = a3tgcn.tgcn(g, x, h) # [out, timesteps, num_nodes] - logits = a3tgcn.dense1(h) - logits = a3tgcn.dense2(logits) # [out, timesteps, num_nodes] - a = softmax(logits, dims=2) # TODO handle multiple graphs - c = sum(a .* h, dims=2) - c = dropdims(c, dims=2) # [out, num_nodes] - return c -end - -function Base.show(io::IO, a3tgcn::A3TGCN) - print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))") -end From 6a23a70ede1970fb7e797e4ef2e8530c7ecf343f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 09:52:58 +0100 Subject: [PATCH 09/11] tests --- .../src/layers/temporalconv.jl | 2 + .../test/layers/temporalconv.jl | 211 +++++++++++++----- GraphNeuralNetworks/test/test_module.jl | 1 + 3 files changed, 152 insertions(+), 62 deletions(-) diff --git a/GraphNeuralNetworks/src/layers/temporalconv.jl b/GraphNeuralNetworks/src/layers/temporalconv.jl index d0cd4b757..bc12f2fab 100644 --- a/GraphNeuralNetworks/src/layers/temporalconv.jl +++ b/GraphNeuralNetworks/src/layers/temporalconv.jl @@ -696,6 +696,8 @@ function Flux.initialstates(cell::EvolveGCNOCell) return (; weight, lstm) end +(cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix) = cell(g, x, initialstates(cell)) + function (cell::EvolveGCNOCell)(g::GNNGraph, x::AbstractMatrix, state) weight, state_lstm = cell.lstm(state.weight, state.lstm) x = cell.conv(g, x, conv_weight = reshape(weight, (cell.out, cell.in))) diff --git a/GraphNeuralNetworks/test/layers/temporalconv.jl b/GraphNeuralNetworks/test/layers/temporalconv.jl index 277783d4f..f8d96c0e2 100644 --- a/GraphNeuralNetworks/test/layers/temporalconv.jl +++ b/GraphNeuralNetworks/test/layers/temporalconv.jl @@ -1,16 +1,20 @@ @testmodule TemporalConvTestModule begin using GraphNeuralNetworks - export in_channel, out_channel, N, timesteps, g, tg, RTOL_LOW, RTOL_HIGH, ATOL_LOW + using Statistics + export in_channel, out_channel, N, timesteps, g, tg, cell_loss, + RTOL_LOW, ATOL_LOW, RTOL_HIGH RTOL_LOW = 1e-2 - RTOL_HIGH = 1e-5 ATOL_LOW = 1e-3 + RTOL_HIGH = 1e-5 in_channel = 3 out_channel = 5 N = 4 timesteps = 5 + cell_loss(cell, g, x...) 
= mean(cell(g, x...)[1]) + g = GNNGraph(rand_graph(N, 8), ndata = rand(Float32, in_channel, N), graph_type = :coo) @@ -22,80 +26,163 @@ end @testitem "TGCNCell" setup=[TemporalConvTestModule, TestModule] begin using .TemporalConvTestModule, .TestModule cell = GraphNeuralNetworks.TGCNCell(in_channel => out_channel) - h = cell(g, g.x) + y, h = cell(g, g.x) + @test y === h @test size(h) == (out_channel, g.num_nodes) - test_gradients(cell, g, g.x, rtol = RTOL_HIGH) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_HIGH) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_HIGH) end @testitem "TGCN" setup=[TemporalConvTestModule, TestModule] begin using .TemporalConvTestModule, .TestModule - tgcn = TGCN(in_channel => out_channel) + layer = TGCN(in_channel => out_channel) x = rand(Float32, in_channel, timesteps, g.num_nodes) - h = tgcn(g, x) - @test size(h) == (out_channel, timesteps, g.num_nodes) - test_gradients(tgcn, g, x, rtol = RTOL_HIGH) - test_gradients(tgcn, g, x, h[:,1,:], rtol = RTOL_HIGH) - - # model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) - # @test size(model(g1, g1.ndata.x)) == (1, N) - # @test model(g1) isa GNNGraph + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test layer isa GNNRecurrence + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol = RTOL_HIGH) + # with initial state + test_gradients(layer, g, x, state0, rtol = RTOL_HIGH) + + # interplay with GNNChain + model = GNNChain(TGCN(in_channel => out_channel), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_HIGH, atol = ATOL_LOW) end -# @testitem "A3TGCN" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# a3tgcn = A3TGCN(in_channel => out_channel) -# @test size(Flux.gradient(x -> sum(a3tgcn(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(A3TGCN(in_channel => out_channel), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = GConvLSTMCell(in_channel => out_channel, 2) + y, (h, c) = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + @test size(c) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, (h, c), loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "GConvLSTMCell" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GraphNeuralNetworks.GConvLSTMCell(in_channel => out_channel, 2, g1.num_nodes) -# (h, c), h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) -# @test size(h) == (out_channel, N) -# @test size(c) == (out_channel, N) -# end +@testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = GConvLSTM(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = (rand(Float32, out_channel, g.num_nodes), rand(Float32, out_channel, g.num_nodes)) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, 
g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(GConvLSTM(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end -# @testitem "GConvLSTM" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GConvLSTM(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(GConvLSTM(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# end +@testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = GConvGRUCell(in_channel => out_channel, 2) + y, h = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "GConvGRUCell" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes) -# h, h = gconvlstm(gconvlstm.state0, g1, g1.ndata.x) -# @test size(h) == (out_channel, N) -# end -# @testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# gconvlstm = GConvGRU(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(gconvlstm(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "GConvGRU" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = GConvGRU(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(GConvGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end -# @testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# dcgru = DCGRU(in_channel => out_channel, 2, g1.num_nodes) -# @test size(Flux.gradient(x -> sum(dcgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N) -# model = GNNChain(DCGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1)) -# @test size(model(g1, g1.ndata.x)) == (1, N) -# @test model(g1) isa GNNGraph -# end +@testitem "DCGRUCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = DCGRUCell(in_channel => out_channel, 2) + y, h = cell(g, g.x) + @test y === h + @test size(h) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, 
loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, h, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end -# @testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin -# using .TemporalConvTestModule, .TestModule -# evolvegcno = EvolveGCNO(in_channel => out_channel) -# @test length(Flux.gradient(x -> sum(sum(evolvegcno(tg, x))), tg.ndata.x)[1]) == S -# @test size(evolvegcno(tg, tg.ndata.x)[1]) == (out_channel, N) -# end +@testitem "DCGRU" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = DCGRU(in_channel => out_channel, 2) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = rand(Float32, out_channel, g.num_nodes) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(DCGRU(in_channel => out_channel, 2), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol = RTOL_LOW, atol = ATOL_LOW) +end + +@testitem "EvolveGCNOCell" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + cell = EvolveGCNOCell(in_channel => out_channel) + y, state = cell(g, g.x) + @test size(y) == (out_channel, g.num_nodes) + # with no initial state + test_gradients(cell, g, g.x, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(cell, g, g.x, state, loss=cell_loss, rtol=RTOL_LOW, atol=ATOL_LOW) +end + +@testitem "EvolveGCNO" setup=[TemporalConvTestModule, TestModule] begin + using .TemporalConvTestModule, .TestModule + layer = EvolveGCNO(in_channel => out_channel) + @test layer isa GNNRecurrence + x = rand(Float32, in_channel, timesteps, g.num_nodes) + state0 = Flux.initialstates(layer) + y = layer(g, x) + @test size(y) == (out_channel, timesteps, g.num_nodes) + # with no initial state + test_gradients(layer, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) + # with initial state + test_gradients(layer, g, x, state0, rtol=RTOL_LOW, atol=ATOL_LOW) + + # interplay with GNNChain + model = GNNChain(EvolveGCNO(in_channel => out_channel), Dense(out_channel, 1)) + y = model(g, x) + @test size(y) == (1, timesteps, g.num_nodes) + test_gradients(model, g, x, rtol=RTOL_LOW, atol=ATOL_LOW) +end # @testitem "GINConv" setup=[TemporalConvTestModule, TestModule] begin # using .TemporalConvTestModule, .TestModule diff --git a/GraphNeuralNetworks/test/test_module.jl b/GraphNeuralNetworks/test/test_module.jl index f21e0a298..8f7a0446b 100644 --- a/GraphNeuralNetworks/test/test_module.jl +++ b/GraphNeuralNetworks/test/test_module.jl @@ -71,6 +71,7 @@ function check_equal_leaves(a, b; rtol=1e-4, atol=1e-4) # @assert isapprox(x, y; rtol, atol) if !isapprox(x, y; rtol, atol) equal = false + # @show x y end end end From cc387e43095b7aa59ca4b6c921671e09830986c4 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 11:26:13 +0100 Subject: [PATCH 10/11] fix gatedgraphconv --- GNNLux/src/layers/conv.jl | 6 +++++- GNNlib/src/layers/conv.jl | 3 +-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index f92dd1ec6..63c4f90b4 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -1261,7 +1261,11 @@ 
LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l function (l::GatedGraphConv)(g, x, ps, st) gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) - fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style + # make the forward compatible with Flux.GRUCell style + function fgru(x, h) + y, (h, ) = gru((x, (h,))) + return y, h + end m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) return GNNlib.gated_graph_conv(m, g, x), st end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index fe7c27d9c..bd9bd18b3 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix) for i in 1:(l.num_layers) m = view(l.weight, :, :, i) * h m = propagate(copy_xj, g, l.aggr; xj = m) - # in gru forward, hidden state is first argument, input is second - h = l.gru(m, h) + _, h = l.gru(m, h) end return h end From 83207c444e26620400da43ba543a66236eae8080 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Tue, 17 Dec 2024 11:47:08 +0100 Subject: [PATCH 11/11] fix set2set --- GNNlib/src/layers/pool.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GNNlib/src/layers/pool.jl b/GNNlib/src/layers/pool.jl index 991e18465..40f983689 100644 --- a/GNNlib/src/layers/pool.jl +++ b/GNNlib/src/layers/pool.jl @@ -31,9 +31,9 @@ function set2set_pool(l, g::GNNGraph, x::AbstractMatrix) qstar = zeros_like(x, (2*n_in, g.num_graphs)) h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2)) + state = (h, c) for t in 1:l.num_iters - h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs] - q = h + q, state = l.lstm(qstar, state) # [n_in, n_graphs] qn = broadcast_nodes(g, q) # [n_in, n_nodes] α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes] r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
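
A minimal usage sketch of the cell/recurrence interface introduced by this patch series. It assumes the API exactly as added in the diffs above (`GConvGRUCell`, `GNNRecurrence`, and the `y, state = cell(g, xt, state)` calling convention); the graph, feature sizes, and variable names below are illustrative only and are not part of the patches.

```julia
using Flux, GraphNeuralNetworks

g = rand_graph(5, 10)              # static graph: 5 nodes, 10 edges
x = rand(Float32, 2, 7, 5)         # 2 input features, 7 timesteps, 5 nodes

cell = GConvGRUCell(2 => 3, 2)     # Chebyshev order 2

# One explicit recurrence step: every temporal cell returns (output, state).
state = Flux.initialstates(cell)
y1, state = cell(g, x[:, 1, :], state)   # y1 is a 3 x 5 matrix

# The same cell wrapped to scan the whole sequence at once;
# per the diffs, GConvGRU(2 => 3, 2) is shorthand for this wrapping.
layer = GNNRecurrence(cell)
y = layer(g, x)                    # size(y) == (3, 7, 5)
```

The same `(output, state)` pattern holds for `GConvLSTMCell`, `DCGRUCell`, `EvolveGCNOCell` and `TGCNCell`, which is also why the final two patches adapt `gated_graph_conv` and `set2set_pool` to unpack the tuple returned by Flux's `GRUCell` and `LSTMCell`.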