diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index 332b578c6..bcd243df3 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -26,8 +26,13 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; st = LuxCore.initialstates(rng, l) @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) @test LuxCore.statelength(l) == LuxCore.statelength(st) - - y, st′ = l(g, x, edge_weight, ps, st) + + if edge_weight !== nothing + y, st′ = l(g, x, ps, st) + else + y, st′ = l(g, x, edge_weight, ps, st) + end + @test eltype(y) == eltype(x) if outputsize !== nothing @test LuxCore.outputsize(l) == outputsize @@ -42,4 +47,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) end -end \ No newline at end of file +end