diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 689ad724b..8dac6eca2 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -33,7 +33,7 @@ export AGNNConv, # GMMConv, GraphConv, MEGNetConv, - # NNConv, + NNConv, # ResGatedGraphConv, # SAGEConv, SGConv diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 04415f1f6..007901a2f 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -669,3 +669,62 @@ function Base.show(io::IO, l::MEGNetConv) print(io, "MEGNetConv(", nin, " => ", nout) print(io, ")") end + +@concrete struct NNConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractLuxLayer + aggr + in_dims::Int + out_dims::Int + use_bias::Bool + init_weight + init_bias + σ +end + +function NNConv(ch::Pair{Int, Int}, nn, σ = identity; + aggr = +, + init_bias = zeros32, + use_bias::Bool = true, + init_weight = glorot_uniform) + in_dims, out_dims = ch + σ = NNlib.fast_act(σ) + return NNConv(nn, aggr, in_dims, out_dims, use_bias, init_weight, init_bias, σ) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::NNConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + ps = (; nn = LuxCore.initialparameters(rng, l.nn), weight) + if l.use_bias + ps = (; ps..., bias = l.init_bias(rng, l.out_dims)) + end + return ps +end + +function LuxCore.initialstates(rng::AbstractRNG, l::NNConv) + return (; nn = LuxCore.initialstates(rng, l.nn)) +end + +function LuxCore.parameterlength(l::NNConv) + n = parameterlength(l.nn) + l.in_dims * l.out_dims + if l.use_bias + n += l.out_dims + end + return n +end + +LuxCore.statelength(l::NNConv) = statelength(l.nn) + +function (l::NNConv)(g, x, e, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps.nn, st.nn) + m = (; nn, l.aggr, ps.weight, bias = _getbias(ps), l.σ) + y = GNNlib.nn_conv(m, g, x, e) + stnew = _getstate(nn) + return y, stnew +end + +function Base.show(io::IO, l::NNConv) + print(io, "NNConv($(l.nn)") + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index c0f0d28e3..4151c81e9 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -106,4 +106,26 @@ @test size(x_new) == (out_dims, g.num_nodes) @test size(e_new) == (out_dims, g.num_edges) end + + @testset "NNConv" begin + n_in = 3 + n_in_edge = 10 + n_out = 5 + + s = [1,1,2,3] + t = [2,3,1,1] + g2 = GNNGraph(s, t) + + nn = Dense(n_in_edge => n_out * n_in) + l = NNConv(n_in => n_out, nn, tanh, aggr = +) + x = randn(Float32, n_in, g2.num_nodes) + e = randn(Float32, n_in_edge, g2.num_edges) + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + + y, st′ = l(g2, x, e, ps, st) + + @test size(y) == (n_out, g2.num_nodes) + end end diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index b6b80df49..aaf8a8e03 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -14,7 +14,7 @@ export test_lux_layer function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; outputsize=nothing, sizey=nothing, container=false, - atol=1.0f-2, rtol=1.0f-2) + atol=1.0f-2, rtol=1.0f-2, e=nothing) if container @test l isa GNNContainerLayer @@ -27,7 +27,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; @test LuxCore.parameterlength(l) == LuxCore.parameterlength(ps) @test LuxCore.statelength(l) == LuxCore.statelength(st) - y, st′ = l(g, x, ps, st) + if e !== nothing + y, st′ = l(g, x, e, ps, st) + else + y, st′ = l(g, x, ps, st) + end @test eltype(y) == eltype(x) if outputsize !== nothing @test LuxCore.outputsize(l) == outputsize @@ -42,4 +46,4 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) end -end \ No newline at end of file +end