Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Sep 9, 2024
1 parent d49bde7 commit 8e6175b
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,10 +274,9 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)

DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))

@concrete struct EvolveGCNO <: GNNContainerLayer{(:conv,)}
@concrete struct EvolveGCNO <: GNNLayer
in_dims::Int
out_dims::Int
conv
use_bias::Bool
init_weight
init_state::Function
Expand All @@ -286,8 +285,7 @@ end

function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
conv = GCNConv(ch; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
return EvolveGCNO(in_dims, out_dims, conv, use_bias, init_weight, init_state, init_bias)
return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO)
Expand Down Expand Up @@ -319,20 +317,20 @@ function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO)
end

function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple)
H = st.lstm.h
C = st.lstm.c
H, C = st.lstm
W = ps.conv.weight
m = (; ps.conv.weight, bias = _getbias(ps),
add_self_loops =true, use_edge_weight=true, σ = identity)

X = map(1:tg.num_snapshots) do i
F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W + ps.lstm.Uf .* H + ps.lstm.Bf)
I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W + ps.lstm.Ui .* H + ps.lstm.Bi)
O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W + ps.lstm.Uo .* H + ps.lstm.Bo)
= NNlib.tanh_fast.(ps.lstm.Wc .* W + ps.lstm.Uc .* H + ps.lstm.Bc)
F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf)
I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi)
O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo)
= NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc)
C = F .* C + I .*
H = O .* NNlib.tanh_fast.(C)
W = H
X, _ = l.conv(tg.snapshots[i], x[i], ps.conv, st.conv; conv_weight = H)
X
GNNlib.gcn_conv(m,tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W)
end
return X, (conv = (), lstm = (h = H, c = C))
return X, (; conv = (;), lstm = (h = H, c = C))
end

0 comments on commit 8e6175b

Please sign in to comment.