rewrite recurrent temporal layers for Flux v0.16 (#560)
* GConvGRU

* GConvLSTM

* GNNRecurrence

* EvolveGCNOCell

* cleanup

* EvolveGCNO

* TGCNCell

* TGCN

* tests

* fix gatedgraphconv

* fix set2set
CarloLucibello authored Dec 17, 2024
1 parent bbff8a9 commit 27d13c8
Showing 10 changed files with 1,019 additions and 618 deletions.
6 changes: 5 additions & 1 deletion GNNLux/src/layers/conv.jl
@@ -1261,7 +1261,11 @@ LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l
 
 function (l::GatedGraphConv)(g, x, ps, st)
     gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru))
-    fgru = (x, h) -> gru((x, (h,)))[1] # make the forward compatible with Flux.GRUCell style
+    # make the forward compatible with Flux.GRUCell style
+    function fgru(x, h)
+        y, (h, ) = gru((x, (h,)))
+        return y, h
+    end
     m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims)
     return GNNlib.gated_graph_conv(m, g, x), st
 end
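For context, here is a minimal sketch of the two cell-calling conventions this adapter bridges. Dimensions are illustrative; the explicit parameter/state plumbing on the Lux side is what `StatefulLuxLayer` automates in the code above, and the exact return shapes are assumptions based on the documented conventions of Flux v0.16 and Lux.

```julia
using Flux, Lux, Random

x = rand(Float32, 3, 5)   # 3 input features, 5 nodes
h = rand(Float32, 4, 5)   # 4 hidden features, 5 nodes

# Flux v0.16 convention: cell(x, h) returns (output, new_state).
flux_cell = Flux.GRUCell(3 => 4)
y, h_new = flux_cell(x, h)

# Lux convention: the cell takes (x, (h,)) plus explicit parameters and
# state, and returns ((output, (new_hidden,)), layer_state).
lux_cell = Lux.GRUCell(3 => 4)
ps, st = Lux.setup(Random.default_rng(), lux_cell)
(y2, (h2,)), st = lux_cell((x, (h,)), ps, st)
```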
3 changes: 1 addition & 2 deletions GNNlib/src/layers/conv.jl
@@ -227,8 +227,7 @@ function gated_graph_conv(l, g::GNNGraph, x::AbstractMatrix)
     for i in 1:(l.num_layers)
         m = view(l.weight, :, :, i) * h
         m = propagate(copy_xj, g, l.aggr; xj = m)
-        # in gru forward, hidden state is first argument, input is second
-        h = l.gru(m, h)
+        _, h = l.gru(m, h)
     end
     return h
 end
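As a hedged illustration of why the call site changed: under Flux v0.16 a recurrent cell call returns an `(output, state)` pair, so a loop that only threads the state discards the first element. A minimal standalone sketch, with illustrative sizes and a plain `GRUCell` standing in for the layer's `gru` field:

```julia
using Flux

# Repeatedly feed a fixed input m to a GRU cell, keeping only the state,
# mirroring the loop in gated_graph_conv above.
function unroll(gru, m, h, num_layers)
    for _ in 1:num_layers
        _, h = gru(m, h)   # Flux v0.16: (output, state); keep only the state
    end
    return h
end

gru = Flux.GRUCell(4 => 4)
m = rand(Float32, 4, 6)    # stand-in for the aggregated messages
h = zeros(Float32, 4, 6)   # one hidden vector per node
h = unroll(gru, m, h, 3)
```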
4 changes: 2 additions & 2 deletions GNNlib/src/layers/pool.jl
@@ -31,9 +31,9 @@ function set2set_pool(l, g::GNNGraph, x::AbstractMatrix)
     qstar = zeros_like(x, (2*n_in, g.num_graphs))
     h = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
     c = zeros_like(l.lstm.Wh, size(l.lstm.Wh, 2))
+    state = (h, c)
     for t in 1:l.num_iters
-        h, c = l.lstm(qstar, (h, c)) # [n_in, n_graphs]
-        q = h
+        q, state = l.lstm(qstar, state) # [n_in, n_graphs]
         qn = broadcast_nodes(g, q) # [n_in, n_nodes]
         α = softmax_nodes(g, sum(qn .* x, dims = 1)) # [1, n_nodes]
         r = reduce_nodes(+, g, x .* α) # [n_in, n_graphs]
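A standalone sketch of the state threading used above, under the assumption (consistent with this diff) that Flux v0.16's `LSTMCell` returns `(output, (h, c))`; sizes are illustrative:

```julia
using Flux

n_in, n_graphs = 3, 2
lstm = Flux.LSTMCell(2n_in => n_in)

qstar = zeros(Float32, 2n_in, n_graphs)
# Vector-shaped (h, c), as the zeros_like calls above produce;
# they broadcast against the per-graph columns of qstar.
state = (zeros(Float32, n_in), zeros(Float32, n_in))

q, state = lstm(qstar, state)   # q is n_in × n_graphs
```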
4 changes: 3 additions & 1 deletion GraphNeuralNetworks/Project.toml
@@ -5,6 +5,7 @@ version = "1.0.0-DEV"
 
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 GNNGraphs = "aed8fd31-079b-4b5a-b342-a13352159b8c"
 GNNlib = "a6a84749-d869-43f8-aacc-be26a1996e48"
@@ -18,7 +19,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 
 [compat]
 ChainRulesCore = "1"
-Flux = "0.15"
+ConcreteStructs = "0.2.3"
+Flux = "0.16.0"
 GNNGraphs = "1.4"
 GNNlib = "1"
15 changes: 8 additions & 7 deletions GraphNeuralNetworks/src/GraphNeuralNetworks.jl
@@ -3,12 +3,13 @@ module GraphNeuralNetworks
 
 using Statistics: mean
 using LinearAlgebra, Random
 using Flux
-using Flux: glorot_uniform, leakyrelu, GRUCell, batch
+using Flux: glorot_uniform, leakyrelu, GRUCell, batch, initialstates
 using MacroTools: @forward
 using NNlib
 using ChainRulesCore
 using Reexport: @reexport
 using MLUtils: zeros_like
+using ConcreteStructs: @concrete
 
 using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T,
                  check_num_nodes, check_num_edges,
@@ -49,12 +50,12 @@ include("layers/heteroconv.jl")
 export HeteroGraphConv
 
 include("layers/temporalconv.jl")
-export TGCN,
-       A3TGCN,
-       GConvLSTM,
-       GConvGRU,
-       DCGRU,
-       EvolveGCNO
+export GNNRecurrence,
+       GConvGRU, GConvGRUCell,
+       GConvLSTM, GConvLSTMCell,
+       DCGRU, DCGRUCell,
+       EvolveGCNO, EvolveGCNOCell,
+       TGCN, TGCNCell
 
 include("layers/pool.jl")
 export GlobalPool,
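The rewrite replaces the old monolithic recurrent layers with per-step cell types plus a generic `GNNRecurrence` wrapper that unrolls a cell over time. A hedged usage sketch follows; the exact call signatures, the `initialstates` overload, and the `[features, timesteps, nodes]` layout are assumptions based on the new exports and Flux v0.16's cell/recurrence split, not verified against the docstrings:

```julia
using GraphNeuralNetworks, Flux

num_nodes, num_edges = 5, 10
d_in, d_out, timesteps = 2, 3, 4
g = rand_graph(num_nodes, num_edges)

# One step with a cell; assumed to follow Flux v0.16's (output, state) style.
cell = GConvGRUCell(d_in => d_out, 2)        # 2: assumed Chebyshev filter order
state = Flux.initialstates(cell)
x_t = rand(Float32, d_in, num_nodes)
y_t, state = cell(g, x_t, state)

# Whole temporal sequence through the generic wrapper.
layer = GNNRecurrence(GConvGRUCell(d_in => d_out, 2))
x = rand(Float32, d_in, timesteps, num_nodes)
y = layer(g, x)                              # assumed [d_out, timesteps, num_nodes]
```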
6 changes: 6 additions & 0 deletions GraphNeuralNetworks/src/layers/conv.jl
@@ -1,3 +1,9 @@
+# The implementations of the forward passes of the graph convolutional layers live in the
+# `GNNlib` module, in the file src/layers/conv.jl. The `GNNlib` module is re-exported by the
+# GraphNeuralNetworks module.
+# This hurts readability, since one has to look at two different files to understand the
+# implementation of a single layer, but it lets GraphNeuralNetworks.jl and GNNLux.jl share the same code.
+
 @doc raw"""
     GCNConv(in => out, σ=identity; [bias, init, add_self_loops, use_edge_weight])
6 changes: 0 additions & 6 deletions GraphNeuralNetworks/src/layers/pool.jl
@@ -155,12 +155,6 @@ function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
     return Set2Set(lstm, n_iters)
 end
 
-function initialstates(cell::LSTMCell)
-    h = zeros_like(cell.Wh, size(cell.Wh, 2))
-    c = zeros_like(cell.Wh, size(cell.Wh, 2))
-    return h, c
-end
-
 function (l::Set2Set)(g, x)
     return GNNlib.set2set_pool(l, g, x)
 end
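The local helper becomes redundant because Flux v0.16 ships `initialstates` for its cells (note the added import in GraphNeuralNetworks.jl above). A minimal sketch of the replacement, assuming the standard zero `(h, c)` state for `LSTMCell`:

```julia
using Flux

lstm = Flux.LSTMCell(4 => 3)
h, c = Flux.initialstates(lstm)   # zero vectors of the hidden size

x = rand(Float32, 4, 7)           # 4 features, batch of 7
y, (h, c) = lstm(x, (h, c))       # (output, (h, c)) under Flux v0.16
```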