Skip to content

Commit

Permalink
Add GConvGRU, GConvLSTM and DCGRU tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aurorarossi committed Aug 28, 2024
1 parent 1ee29d9 commit a1bfcf6
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions GNNLux/test/layers/temporalconv_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,28 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "GConvGRU" begin
l = GConvGRU(3=>3, 2)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "GConvLSTM" begin
l = GConvLSTM(3=>3, 2)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

@testset "DCGRU" begin
l = DCGRU(3=>3, 2)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
test_gradients(loss, x, ps; atol=1.0f-2, rtol=1.0f-2, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end
end

0 comments on commit a1bfcf6

Please sign in to comment.