diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 2a9cc5852..0a8c4e290 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -35,7 +35,7 @@ export AGNNConv, MEGNetConv, NNConv, ResGatedGraphConv, - # SAGEConv, + SAGEConv, SGConv # TAGConv, # TransformerConv diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index fbf7ad7c2..f0b51066b 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -844,3 +844,51 @@ function Base.show(io::IO, l::ResGatedGraphConv) l.use_bias || print(io, ", use_bias=false") print(io, ")") end + +@concrete struct SAGEConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + init_weight + init_bias + σ + aggr +end + +function SAGEConv(ch::Pair{Int, Int}, σ = identity; + aggr = mean, + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true) + in_dims, out_dims = ch + σ = NNlib.fast_act(σ) + return SAGEConv(in_dims, out_dims, use_bias, init_weight, init_bias, σ, aggr) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::SAGEConv) + weight = l.init_weight(rng, l.out_dims, 2 * l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::SAGEConv) = l.use_bias ? l.out_dims * 2 * l.in_dims + l.out_dims : + l.out_dims * 2 * l.in_dims +LuxCore.outputsize(d::SAGEConv) = (d.out_dims,) + +function Base.show(io::IO, l::SAGEConv) + print(io, "SAGEConv(", l.in_dims, " => ", l.out_dims) + (l.σ == identity) || print(io, ", ", l.σ) + (l.aggr == mean) || print(io, ", aggr=", l.aggr) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end + +function (l::SAGEConv)(g, x, ps, st) + m = (; ps.weight, bias = _getbias(ps), + l.σ, l.aggr) + return GNNlib.sage_conv(m, g, x), st +end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 6541dfe0c..4f871b64e 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -134,4 +134,9 @@ l = ResGatedGraphConv(in_dims => out_dims, tanh) test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) end + + @testset "SAGEConv" begin + l = SAGEConv(in_dims => out_dims, tanh) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end end