diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 301210e7a..4151c81e9 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -124,6 +124,8 @@ ps = LuxCore.initialparameters(rng, l) st = LuxCore.initialstates(rng, l) - test_lux_layer(rng, l, g2, x, sizey=(n_out, g2.num_nodes), container=true, edge_weight=e) + y, st′ = l(g2, x, e, ps, st) + + @test size(y) == (n_out, g2.num_nodes) end end