Skip to content

Commit

Permalink
First draft
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Sep 4, 2024
1 parent bd5e2f2 commit 9d2346a
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,36 @@ Flux.Recur(dcgru::DCGRUCell) = Flux.Recur(dcgru, dcgru.state0)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph, x) = l(g, x)
_applylayer(l::Flux.Recur{DCGRUCell}, g::GNNGraph) = l(g)

struct EvolveGCNO
conv
lstm
W_init
init_state
in::Int
out::Int
end

Flux.@functor EvolveGCNO

function EvolveGCNO(ch; bias = true, init = glorot_uniform, init_state = Flux.zeros32)
in, out = ch
W = init(out, in)
conv = GCNConv(ch; bias = bias, init = init)
lstm = Flux.LSTM(out,out)
return EvolveGCNO(conv, lstm, W, init_state, in, out)
end

function (egcno::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph)
H = egcno.init_state(egcno.out, tg.snapshots[i].num_nodes, tg.num_snapshots)
W = egcno.W_init
for i in 1:tg.num_snapshots
W = egcno.lstm(W)
H[:,:,i] .= egcno.conv(tg.snapshots[i], tg.ndata.x[i]; conv_weight = W)
end
return H
end


function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
return l.(tg.snapshots, x)
end
Expand Down

0 comments on commit 9d2346a

Please sign in to comment.