Improve EvolveGCNO
aurorarossi committed Sep 5, 2024
1 parent 5919956 commit 4ea56b1
Showing 1 changed file with 79 additions and 11 deletions.
90 changes: 79 additions & 11 deletions src/layers/temporalconv.jl
@@ -484,13 +484,59 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

"""
EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
Evolving Graph Convolutional Network (EvolveGCNO) layer from the paper [EvolveGCN: Evolving Graph Convolutional Networks for Dynamic Graphs](https://arxiv.org/pdf/1902.10191).
Performs a graph convolution with parameters derived from a Long Short-Term Memory (LSTM) layer across the snapshots of the temporal graph.
# Arguments
- `in`: Number of input features.
- `out`: Number of output features.
- `bias`: Add learnable bias. Default `true`.
- `init`: Weights' initializer. Default `glorot_uniform`.
- `init_state`: Initial state of the hidden state of the LSTM layer. Default `zeros32`.
# Examples
```jldoctest
julia> tg = TemporalSnapshotsGNNGraph([rand_graph(10,20; ndata = rand(4,10)), rand_graph(10,14; ndata = rand(4,10)), rand_graph(10,22; ndata = rand(4,10))])
TemporalSnapshotsGNNGraph:
num_nodes: [10, 10, 10]
num_edges: [20, 14, 22]
num_snapshots: 3
julia> ev = EvolveGCNO(4 => 5)
EvolveGCNO(4 => 5)
julia> size(ev(tg, tg.ndata.x))
(3,)
julia> size(ev(tg, tg.ndata.x)[1])
(5, 10)
```
"""
struct EvolveGCNO
conv
lstm
W_init
init_state
in::Int
out::Int
Wf
Uf
Bf
Wi
Ui
Bi
Wo
Uo
Bo
Wc
Uc
Bc
end
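# Note on the fields above: `conv` is the underlying GCNConv, `W_init` the initial
# convolution weight matrix, and Wf/Uf/Bf, Wi/Ui/Bi, Wo/Uo/Bo, Wc/Uc/Bc the per-gate
# parameters of the LSTM-style cell that evolves the convolution weights from one
# snapshot to the next (see the forward pass below).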

Flux.@functor EvolveGCNO
@@ -499,20 +545,42 @@ function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.ze
in, out = ch
W = init(out, in)
conv = GCNConv(ch; bias = bias, init = init)
lstm = Flux.LSTM(out,out)
return EvolveGCNO(conv, lstm, W, init_state, in, out)
end

function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph)
H = egcno.init_state(egcno.out, tg.snapshots[i].num_nodes, tg.num_snapshots)
Wf = init(out, in)
Uf = init(out, in)
Bf = bias ? init(out, in) : nothing
Wi = init(out, in)
Ui = init(out, in)
Bi = bias ? init(out, in) : nothing
Wo = init(out, in)
Uo = init(out, in)
Bo = bias ? init(out, in) : nothing
Wc = init(out, in)
Uc = init(out, in)
Bc = bias ? init(out, in) : nothing
return EvolveGCNO(conv, W, init_state, in, out, Wf, Uf, Bf, Wi, Ui, Bi, Wo, Uo, Bo, Wc, Uc, Bc)
end

function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x)
X = egcno.init_state(egcno.out, tg.snapshots[1].num_nodes, tg.num_snapshots)
H = egcno.init_state(egcno.out, egcno.in)
C = egcno.init_state(egcno.out, egcno.in)
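# H and C are the hidden and cell state of the weight-evolving LSTM; both share the
# (out, in) shape of the convolution weight matrix W, since all gates act element-wise.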
W = egcno.W_init
for i in 1:tg.num_snapshots
W = egcno.lstm(W)
H[:,:,i] .= egcno.conv(tg.snapshots[i], tg.ndata.x[i]; conv_weight = W)
X = map(1:tg.num_snapshots) do i
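# LSTM-style gates applied element-wise to the weight matrix:
# F = forget gate, I = input gate, O = output gate, C̃ = candidate cell state.
# The evolved hidden state H is used as the convolution weight for snapshot i.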
F = Flux.sigmoid_fast.(egcno.Wf .* W + egcno.Uf .* H + egcno.Bf)
I = Flux.sigmoid_fast.(egcno.Wi .* W + egcno.Ui .* H + egcno.Bi)
O = Flux.sigmoid_fast.(egcno.Wo .* W + egcno.Uo .* H + egcno.Bo)
C̃ = Flux.tanh.(egcno.Wc .* W + egcno.Uc .* H + egcno.Bc)
C = F .* C + I .* C̃
H = O .* tanh_fast.(C)
W = H
egcno.conv(tg.snapshots[i], x[i]; conv_weight = H)
end
return H
return X
end

function Base.show(io::IO, egcno::EvolveGCNO)
print(io, "EvolveGCNO($(egcno.in) => $(egcno.out))")
end

function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
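Below is a minimal usage sketch (not part of this commit) showing how the per-snapshot outputs of the new `EvolveGCNO` forward pass can feed a downstream per-snapshot layer. The graph sizes, feature dimensions, and the `readout` layer are illustrative choices, and the packages are assumed to be `Flux` and `GraphNeuralNetworks` as in this repository.

```julia
using Flux, GraphNeuralNetworks

# Three snapshots with 10 nodes and 4 input features per node (arbitrary sizes for illustration).
tg = TemporalSnapshotsGNNGraph([rand_graph(10, 20; ndata = rand(Float32, 4, 10)) for _ in 1:3])

ev = EvolveGCNO(4 => 5)        # GCN weights evolved by the LSTM across snapshots
readout = Dense(5 => 1)        # hypothetical per-node readout applied to every snapshot

h = ev(tg, tg.ndata.x)         # Vector of 3 matrices, each of size (5, 10)
scores = map(readout, h)       # one (1, 10) matrix of node scores per snapshot
```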
