Skip to content

Commit

Permalink
Add GConvGRU, GConvLSTM and DCGRU
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Aug 28, 2024
1 parent d6d5f08 commit 1ee29d9
Showing 1 changed file with 181 additions and 0 deletions.
181 changes: 181 additions & 0 deletions GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,184 @@ LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)
function Base.show(io::IO, l::A3TGCN)
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
end

@concrete struct GConvGRUCell <: GNNContainerLayer{(:conv_x_r, :conv_h_r, :conv_x_z, :conv_h_z, :conv_x_h, :conv_h_h)}
in_dims::Int
out_dims::Int
k::Int
conv_x_r
conv_h_r
conv_x_z
conv_h_z
conv_x_h
conv_h_h
init_state::Function
end

function GConvGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
#reset gate
conv_x_r = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_r = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
#update gate
conv_x_z = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_z = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
#hidden state
conv_x_h = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_h = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
return GConvGRUCell(in_dims, out_dims, k, conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, init_state)
end

function (l::GConvGRUCell)(g, (x, h), ps, st)
if h === nothing
h = l.init_state(l.out_dims, g.num_nodes)
end
xr, st_conv_xr = l.conv_x_r(g, x, ps.conv_x_r, st.conv_x_r)
hr, st_conv_hr = l.conv_h_r(g, h, ps.conv_h_r, st.conv_h_r)
r = xr .+ hr
r = NNlib.sigmoid_fast(r)
xz, st_conv_x_z = l.conv_x_z(g, x, ps.conv_x_z, st.conv_x_z)
hz, st_conv_h_z = l.conv_h_z(g, h, ps.conv_h_z, st.conv_h_z)
z = xz .+ hz
z = NNlib.sigmoid_fast(z)
xh, st_conv_x_h = l.conv_x_h(g, x, ps.conv_x_h, st.conv_x_h)
hh, st_conv_h_h = l.conv_h_h(g, r .* h, ps.conv_h_h, st.conv_h_h)
= xh .+ hh
= NNlib.tanh_fast(h)
h = (1 .- z).*+ z.* h
return (h, h), (conv_x_r = st_conv_xr, conv_h_r = st_conv_hr, conv_x_z = st_conv_x_z, conv_h_z = st_conv_h_z, conv_x_h = st_conv_x_h, conv_h_h = st_conv_h_h)
end

function Base.show(io::IO, l::GConvGRUCell)
print(io, "GConvGRUCell($(l.in_dims) => $(l.out_dims))")
end

LuxCore.outputsize(l::GConvGRUCell) = (l.out_dims,)

GConvGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvGRUCell(ch, k; kwargs...))

@concrete struct GConvLSTMCell <: GNNContainerLayer{(:conv_x_i, :conv_h_i, :dense_i, :conv_x_f, :conv_h_f, :dense_f, :conv_x_c, :conv_h_c, :dense_c, :conv_x_o, :conv_h_o, :dense_o)}
in_dims::Int
out_dims::Int
k::Int
conv_x_i
conv_h_i
dense_i
conv_x_f
conv_h_f
dense_f
conv_x_c
conv_h_c
dense_c
conv_x_o
conv_h_o
dense_o
init_state::Function
end

function GConvLSTMCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
#input gate
conv_x_i = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_i = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
dense_i = Dense(out_dims, 1; use_bias, init_weight, init_bias)
#forget gate
conv_x_f = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_f = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
dense_f = Dense(out_dims, 1; use_bias, init_weight, init_bias)
#cell gate
conv_x_c = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_c = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
dense_c = Dense(out_dims, 1; use_bias, init_weight, init_bias)
#output gate
conv_x_o = ChebConv(in_dims => out_dims, k; use_bias, init_weight, init_bias)
conv_h_o = ChebConv(out_dims => out_dims, k; use_bias, init_weight, init_bias)
dense_o = Dense(out_dims, 1; use_bias, init_weight, init_bias)
return GConvLSTMCell(in_dims, out_dims, k, conv_x_i, conv_h_i, dense_i, conv_x_f, conv_h_f, dense_f, conv_x_c, conv_h_c, dense_c, conv_x_o, conv_h_o, dense_o, init_state)
end

function (l::GConvLSTMCell)(g, (x, m), ps, st)
if m === nothing
h = l.init_state(l.out_dims, g.num_nodes)
c = l.init_state(l.out_dims, g.num_nodes)
else
h, c = m
end

dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
dense_c = StatefulLuxLayer{true}(l.dense_c, ps.dense_c, _getstate(st, :dense_c))
dense_o = StatefulLuxLayer{true}(l.dense_o, ps.dense_o, _getstate(st, :dense_o))

xi, st_conv_x_i = l.conv_x_i(g, x, ps.conv_x_i, st.conv_x_i)
hi, st_conv_h_i = l.conv_h_i(g, h, ps.conv_h_i, st.conv_h_i)
i = xi .+ hi .+ dense_i(c)
i = NNlib.sigmoid_fast(i)

xf, st_conv_x_f = l.conv_x_f(g, x, ps.conv_x_f, st.conv_x_f)
hf, st_conv_h_f = l.conv_h_f(g, h, ps.conv_h_f, st.conv_h_f)
f = xf .+ hf .+ dense_f(c)
f = NNlib.sigmoid_fast(f)

xc, st_conv_x_c = l.conv_x_c(g, x, ps.conv_x_c, st.conv_x_c)
hc, st_conv_h_c = l.conv_h_c(g, h, ps.conv_h_c, st.conv_h_c)
c = f .* c + i.* NNlib.tanh_fast(xc .+ hc .+ dense_c(c))

xo, st_conv_x_o = l.conv_x_o(g, x, ps.conv_x_o, st.conv_x_o)
ho, st_conv_h_o = l.conv_h_o(g, h, ps.conv_h_o, st.conv_h_o)
o = xo .+ ho .+ dense_o(c)
o = NNlib.sigmoid_fast(o)
h = o.* NNlib.tanh_fast(c)
return (h, (h, c)), (conv_x_i = st_conv_x_i, conv_h_i = st_conv_h_i, conv_x_f = st_conv_x_f, conv_h_f = st_conv_h_f, conv_x_c = st_conv_x_c, conv_h_c = st_conv_h_c, conv_x_o = st_conv_x_o, conv_h_o = st_conv_h_o)
end

function Base.show(io::IO, l::GConvLSTMCell)
print(io, "GConvLSTMCell($(l.in_dims) => $(l.out_dims))")
end

LuxCore.outputsize(l::GConvLSTMCell) = (l.out_dims,)

GConvLSTM(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(GConvLSTMCell(ch, k; kwargs...))

@concrete struct DCGRUCell <: GNNContainerLayer{(:dconv_u, :dconv_r, :dconv_c)}
in_dims::Int
out_dims::Int
k::Int
dconv_u
dconv_r
dconv_c
init_state::Function
end

function DCGRUCell(ch::Pair{Int, Int}, k::Int; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32)
in_dims, out_dims = ch
dconv_u = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
dconv_r = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
dconv_c = DConv((in_dims + out_dims) => out_dims, k; use_bias = use_bias, init_weight = init_weight, init_bias = init_bias)
return DCGRUCell(in_dims, out_dims, k, dconv_u, dconv_r, dconv_c, init_state)
end

function (l::DCGRUCell)(g, (x, h), ps, st)
if h === nothing
h = l.init_state(l.out_dims, g.num_nodes)
end
= vcat(x, h)
z, st_dconv_u = l.dconv_u(g, h̃, ps.dconv_u, st.dconv_u)
z = NNlib.sigmoid_fast.(z)
r, st_dconv_r = l.dconv_r(g, h̃, ps.dconv_r, st.dconv_r)
r = NNlib.sigmoid_fast.(r)
= vcat(x, h .* r)
c, st_dconv_c = l.dconv_c(g, ĥ, ps.dconv_c, st.dconv_c)
c = NNlib.tanh_fast.(c)
h = z.* h + (1 .- z).* c
return (h, h), (dconv_u = st_dconv_u, dconv_r = st_dconv_r, dconv_c = st_dconv_c)
end

function Base.show(io::IO, l::DCGRUCell)
print(io, "DCGRUCell($(l.in_dims) => $(l.out_dims))")
end

LuxCore.outputsize(l::DCGRUCell) = (l.out_dims,)

DCGRU(ch::Pair{Int, Int}, k::Int; kwargs...) = GNNLux.StatefulRecurrentCell(DCGRUCell(ch, k; kwargs...))

0 comments on commit 1ee29d9

Please sign in to comment.