From 3efa8e082b5fc46fb0be1517caeb9cf6ef503553 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Mon, 23 Sep 2024 10:43:59 +0200 Subject: [PATCH] [GNNLux] add GMMConv, ResGatedGraphConv --- GNNLux/src/GNNLux.jl | 4 +- GNNLux/src/layers/conv.jl | 118 ++++++++++++++++++++++++++++++- GNNLux/test/layers/conv_tests.jl | 18 +++-- GNNLux/test/shared_testsetup.jl | 6 +- GNNlib/src/layers/conv.jl | 2 +- src/layers/conv.jl | 6 +- 6 files changed, 141 insertions(+), 13 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 4a72d8d33..2a9cc5852 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -30,11 +30,11 @@ export AGNNConv, GatedGraphConv, GCNConv, GINConv, - # GMMConv, + GMMConv, GraphConv, MEGNetConv, NNConv, - # ResGatedGraphConv, + ResGatedGraphConv, # SAGEConv, SGConv # TAGConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 007901a2f..fbf7ad7c2 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -628,6 +628,68 @@ function Base.show(io::IO, l::GINConv) print(io, ")") end +@concrete struct GMMConv <: GNNLayer + σ + ch::Pair{NTuple{2, Int}, Int} + K::Int + residual::Bool + init_weight + init_bias + use_bias::Bool + dense_x +end + +function GMMConv(ch::Pair{NTuple{2, Int}, Int}, + σ = identity; + K::Int = 1, + residual = false, + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias = true) + dense_x = Dense(ch[1][1] => ch[2] * K, use_bias = false) + return GMMConv(σ, ch, K, residual, init_weight, init_bias, use_bias, dense_x) +end + + +function LuxCore.initialparameters(rng::AbstractRNG, l::GMMConv) + ein = l.ch[1][2] + mu = l.init_weight(rng, ein, l.K) + sigma_inv = l.init_weight(rng, ein, l.K) + ps = (; mu, sigma_inv, dense_x = LuxCore.initialparameters(rng, l.dense_x)) + if l.use_bias + bias = l.init_bias(rng, l.ch[2]) + ps = (; ps..., bias) + end + return ps +end + +LuxCore.outputsize(l::GMMConv) = (l.ch[2],) + +function LuxCore.parameterlength(l::GMMConv) + n = 2 * l.ch[1][2] * l.K + n += parameterlength(l.dense_x) + if l.use_bias + n += l.ch[2] + end + return n +end + +function (l::GMMConv)(g::GNNGraph, x, e, ps, st) + dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x)) + m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps)) + return GNNlib.gmm_conv(m, g, x, e), st +end + +function Base.show(io::IO, l::GMMConv) + (nin, ein), out = l.ch + print(io, "GMMConv((", nin, ",", ein, ")=>", out) + l.σ == identity || print(io, ", σ=", l.dense_s.σ) + print(io, ", K=", l.K) + print(io, ", residual=", l.residual) + l.use_bias == true || print(io, ", use_bias=false") + print(io, ")") +end + @concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)} in_dims::Int out_dims::Int @@ -712,6 +774,8 @@ function LuxCore.parameterlength(l::NNConv) return n end +LuxCore.outputsize(l::NNConv) = (l.out_dims,) + LuxCore.statelength(l::NNConv) = statelength(l.nn) function (l::NNConv)(g, x, e, ps, st) @@ -723,7 +787,59 @@ function (l::NNConv)(g, x, e, ps, st) end function Base.show(io::IO, l::NNConv) - print(io, "NNConv($(l.nn)") + print(io, "NNConv($(l.in_dims) => $(l.out_dims), $(l.nn)") + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end + +@concrete struct ResGatedGraphConv <: GNNLayer + in_dims::Int + out_dims::Int + σ + init_bias + init_weight + use_bias::Bool +end + +function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true) + in_dims, out_dims = ch + return ResGatedGraphConv(in_dims, out_dims, σ, init_bias, init_weight, use_bias) +end + +function LuxCore.initialparameters(rng::AbstractRNG, l::ResGatedGraphConv) + A = l.init_weight(rng, l.out_dims, l.in_dims) + B = l.init_weight(rng, l.out_dims, l.in_dims) + U = l.init_weight(rng, l.out_dims, l.in_dims) + V = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; A, B, U, V, bias) + else + return (; A, B, U, V) + end +end + +function LuxCore.parameterlength(l::ResGatedGraphConv) + n = 4 * l.in_dims * l.out_dims + if l.use_bias + n += l.out_dims + end + return n +end + +LuxCore.outputsize(l::ResGatedGraphConv) = (l.out_dims,) + +function (l::ResGatedGraphConv)(g, x, ps, st) + m = (; ps.A, ps.B, ps.U, ps.V, bias = _getbias(ps), l.σ) + return GNNlib.res_gated_graph_conv(m, g, x), st +end + +function Base.show(io::IO, l::ResGatedGraphConv) + print(io, "ResGatedGraphConv(", l.in_dims, " => ", l.out_dims) l.σ == identity || print(io, ", ", l.σ) l.use_bias || print(io, ", use_bias=false") print(io, ")") diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 4151c81e9..6541dfe0c 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -120,12 +120,18 @@ l = NNConv(n_in => n_out, nn, tanh, aggr = +) x = randn(Float32, n_in, g2.num_nodes) e = randn(Float32, n_in_edge, g2.num_edges) + test_lux_layer(rng, l, g2, x; outputsize=(n_out,), e, container=true) + end - ps = LuxCore.initialparameters(rng, l) - st = LuxCore.initialstates(rng, l) + @testset "GMMConv" begin + ein_dims = 4 + e = randn(rng, Float32, ein_dims, g.num_edges) + l = GMMConv((in_dims, ein_dims) => out_dims, tanh; K = 2, residual = false) + test_lux_layer(rng, l, g, x; outputsize=(out_dims,), e) + end - y, st′ = l(g2, x, e, ps, st) - - @test size(y) == (n_out, g2.num_nodes) - end + @testset "ResGatedGraphConv" begin + l = ResGatedGraphConv(in_dims => out_dims, tanh) + test_lux_layer(rng, l, g, x, outputsize=(out_dims,)) + end end diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index aaf8a8e03..bf2bdbbf2 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -42,7 +42,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; @test size(y) == (outputsize..., g.num_nodes) end - loss = (x, ps) -> sum(first(l(g, x, ps, st))) + if e !== nothing + loss = (x, ps) -> sum(first(l(g, x, e, ps, st))) + else + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + end test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()]) end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 9caa4280f..e310fa81c 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -389,7 +389,7 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) m = propagate(e_mul_xj, g, mean, xj = xj, e = w) m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes) - m = l.σ(m .+ l.bias) + m = l.σ.(m .+ l.bias) if l.residual if size(x, 1) == size(m, 1) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 13c4f3030..2cdab6a4d 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -717,7 +717,9 @@ end function Base.show(io::IO, l::NNConv) out, in = size(l.weight) print(io, "NNConv($in => $out") - print(io, ", aggr=", l.aggr) + print(io, ", ", l.nn) + l.σ == identity || print(io, ", ", l.σ) + (l.aggr == +) || print(io, "; aggr=", l.aggr) print(io, ")") end @@ -1136,7 +1138,7 @@ function Base.show(io::IO, l::GMMConv) print(io, "GMMConv((", nin, ",", ein, ")=>", out) l.σ == identity || print(io, ", σ=", l.dense_s.σ) print(io, ", K=", l.K) - l.residual == true || print(io, ", residual=", l.residual) + print(io, ", residual=", l.residual) print(io, ")") end