From 8e6175ba2977eeea7f9d5754c46f0c5f8bb7ad93 Mon Sep 17 00:00:00 2001 From: Aurora Rossi Date: Mon, 9 Sep 2024 14:47:37 +0200 Subject: [PATCH] Fix --- GNNLux/src/layers/temporalconv.jl | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/GNNLux/src/layers/temporalconv.jl b/GNNLux/src/layers/temporalconv.jl index 12c0608eb..09594bf67 100644 --- a/GNNLux/src/layers/temporalconv.jl +++ b/GNNLux/src/layers/temporalconv.jl @@ -274,10 +274,9 @@ LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,) DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...)) -@concrete struct EvolveGCNO <: GNNContainerLayer{(:conv,)} +@concrete struct EvolveGCNO <: GNNLayer in_dims::Int out_dims::Int - conv use_bias::Bool init_weight init_state::Function @@ -286,8 +285,7 @@ end function EvolveGCNO(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32) in_dims, out_dims = ch - conv = GCNConv(ch; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias) - return EvolveGCNO(in_dims, out_dims, conv, use_bias, init_weight, init_state, init_bias) + return EvolveGCNO(in_dims, out_dims, use_bias, init_weight, init_state, init_bias) end function LuxCore.initialparameters(rng::AbstractRNG, l::EvolveGCNO) @@ -319,20 +317,20 @@ function LuxCore.initialstates(rng::AbstractRNG, l::EvolveGCNO) end function (l::EvolveGCNO)(tg::TemporalSnapshotsGNNGraph, x, ps::NamedTuple, st::NamedTuple) - H = st.lstm.h - C = st.lstm.c + H, C = st.lstm W = ps.conv.weight + m = (; ps.conv.weight, bias = _getbias(ps), + add_self_loops =true, use_edge_weight=true, σ = identity) + X = map(1:tg.num_snapshots) do i - F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W + ps.lstm.Uf .* H + ps.lstm.Bf) - I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W + ps.lstm.Ui .* H + ps.lstm.Bi) - O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W + ps.lstm.Uo .* H + ps.lstm.Bo) - C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W + ps.lstm.Uc .* H + ps.lstm.Bc) + F = NNlib.sigmoid_fast.(ps.lstm.Wf .* W .+ ps.lstm.Uf .* H .+ ps.lstm.Bf) + I = NNlib.sigmoid_fast.(ps.lstm.Wi .* W .+ ps.lstm.Ui .* H .+ ps.lstm.Bi) + O = NNlib.sigmoid_fast.(ps.lstm.Wo .* W .+ ps.lstm.Uo .* H .+ ps.lstm.Bo) + C̃ = NNlib.tanh_fast.(ps.lstm.Wc .* W .+ ps.lstm.Uc .* H .+ ps.lstm.Bc) C = F .* C + I .* C̃ H = O .* NNlib.tanh_fast.(C) W = H - X, _ = l.conv(tg.snapshots[i], x[i], ps.conv, st.conv; conv_weight = H) - X + GNNlib.gcn_conv(m,tg.snapshots[i], x[i], nothing, d -> 1 ./ sqrt.(d), W) end - return X, (conv = (), lstm = (h = H, c = C)) + return X, (; conv = (;), lstm = (h = H, c = C)) end -