Skip to content

Commit

Permalink
[GNNLux] Add A3TGCN temporal layer (#485)
Browse files Browse the repository at this point in the history
* Export A3TGCN

* Add struct

* Add test A#TGCN

* Fix test
  • Loading branch information
aurorarossi authored Aug 25, 2024
1 parent ed78e88 commit cb82352
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export AGNNConv,

include("layers/temporalconv.jl")
export TGCN
export A3TGCN

end #module

38 changes: 37 additions & 1 deletion GNNLux/src/layers/temporalconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,40 @@ function Base.show(io::IO, tgcn::TGCNCell)
print(io, "TGCNCell($(tgcn.in_dims) => $(tgcn.out_dims))")
end

TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))
TGCN(ch::Pair{Int, Int}; kwargs...) = GNNLux.StatefulRecurrentCell(TGCNCell(ch; kwargs...))

@concrete struct A3TGCN <: GNNContainerLayer{(:tgcn, :dense1, :dense2)}
in_dims::Int
out_dims::Int
tgcn
dense1
dense2
end

function A3TGCN(ch::Pair{Int, Int}; use_bias = true, init_weight = glorot_uniform, init_state = zeros32, init_bias = zeros32, add_self_loops = false, use_edge_weight = true)
in_dims, out_dims = ch
tgcn = TGCN(ch; use_bias, init_weight, init_state, init_bias, add_self_loops, use_edge_weight)
dense1 = Dense(out_dims, out_dims)
dense2 = Dense(out_dims, out_dims)
return A3TGCN(in_dims, out_dims, tgcn, dense1, dense2)
end

function (l::A3TGCN)(g, x, ps, st)
dense1 = StatefulLuxLayer{true}(l.dense1, ps.dense1, _getstate(st, :dense1))
dense2 = StatefulLuxLayer{true}(l.dense2, ps.dense2, _getstate(st, :dense2))
h, st = l.tgcn(g, x, ps.tgcn, st.tgcn)
x = dense1(h)
x = dense2(x)
a = NNlib.softmax(x, dims = 3)
c = sum(a .* h , dims = 3)
if length(size(c)) == 3
c = dropdims(c, dims = 3)
end
return c, st
end

LuxCore.outputsize(l::A3TGCN) = (l.out_dims,)

function Base.show(io::IO, l::A3TGCN)
print(io, "A3TGCN($(l.in_dims) => $(l.out_dims))")
end
10 changes: 9 additions & 1 deletion GNNLux/test/layers/temporalconv_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
using LuxTestUtils: test_gradients, AutoReverseDiff, AutoTracker, AutoForwardDiff, AutoEnzyme

rng = StableRNG(1234)
g = rand_graph(10, 40, seed=1234)
g = rand_graph(rng, 10, 40)
x = randn(rng, Float32, 3, 10)

@testset "TGCN" begin
Expand All @@ -12,4 +12,12 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "A3TGCN" begin
l = A3TGCN(3=>3)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
end

0 comments on commit cb82352

Please sign in to comment.