[GNNLux] add GMMConv, ResGatedGraphConv #494

Merged (1 commit) on Sep 23, 2024
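As context for the diff below, a minimal usage sketch of the two newly exported layers, assuming GNNLux's explicit parameter/state convention visible in this PR and GNNGraphs' `rand_graph` helper; the feature dimensions are made up for illustration:

    using GNNLux, GNNGraphs, LuxCore, Random

    rng = Random.default_rng()
    g = rand_graph(10, 30)                  # toy graph: 10 nodes, 30 edges
    x = randn(Float32, 3, g.num_nodes)      # node features (3 channels)
    e = randn(Float32, 4, g.num_edges)      # edge features (4 channels)

    # GMMConv takes (node_in, edge_in) => out and consumes edge features.
    l = GMMConv((3, 4) => 5, tanh; K = 2)
    ps = LuxCore.initialparameters(rng, l)
    st = LuxCore.initialstates(rng, l)
    y, st = l(g, x, e, ps, st)              # size(y) == (5, g.num_nodes)

    # ResGatedGraphConv operates on node features only.
    l2 = ResGatedGraphConv(3 => 5, tanh)
    ps2 = LuxCore.initialparameters(rng, l2)
    st2 = LuxCore.initialstates(rng, l2)
    y2, st2 = l2(g, x, ps2, st2)            # size(y2) == (5, g.num_nodes)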
4 changes: 2 additions & 2 deletions GNNLux/src/GNNLux.jl
@@ -30,11 +30,11 @@ export AGNNConv,
GatedGraphConv,
GCNConv,
GINConv,
- # GMMConv,
+ GMMConv,
GraphConv,
MEGNetConv,
NNConv,
- # ResGatedGraphConv,
+ ResGatedGraphConv,
# SAGEConv,
SGConv
# TAGConv,
118 changes: 117 additions & 1 deletion GNNLux/src/layers/conv.jl
@@ -628,6 +628,68 @@ function Base.show(io::IO, l::GINConv)
print(io, ")")
end

@concrete struct GMMConv <: GNNLayer
σ
ch::Pair{NTuple{2, Int}, Int}
K::Int
residual::Bool
init_weight
init_bias
use_bias::Bool
dense_x
end

function GMMConv(ch::Pair{NTuple{2, Int}, Int},
σ = identity;
K::Int = 1,
residual = false,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias = true)
dense_x = Dense(ch[1][1] => ch[2] * K, use_bias = false)
return GMMConv(σ, ch, K, residual, init_weight, init_bias, use_bias, dense_x)
end


function LuxCore.initialparameters(rng::AbstractRNG, l::GMMConv)
ein = l.ch[1][2]
mu = l.init_weight(rng, ein, l.K)
sigma_inv = l.init_weight(rng, ein, l.K)
ps = (; mu, sigma_inv, dense_x = LuxCore.initialparameters(rng, l.dense_x))
if l.use_bias
bias = l.init_bias(rng, l.ch[2])
ps = (; ps..., bias)
end
return ps
end

LuxCore.outputsize(l::GMMConv) = (l.ch[2],)

function LuxCore.parameterlength(l::GMMConv)
n = 2 * l.ch[1][2] * l.K
n += parameterlength(l.dense_x)
if l.use_bias
n += l.ch[2]
end
return n
end
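# Illustrative check (not part of this diff): for the instance used in the
# tests below, GMMConv((3, 4) => 5, tanh; K = 2):
#   mu and sigma_inv: 2 * ein * K = 2 * 4 * 2      = 16
#   dense_x = Dense(3 => 5 * 2; use_bias = false)  = 3 * 10 = 30
#   bias: out                                      = 5
# so LuxCore.parameterlength(l) == 51.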

function (l::GMMConv)(g::GNNGraph, x, e, ps, st)
dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
m = (; ps.mu, ps.sigma_inv, dense_x, l.σ, l.ch, l.K, l.residual, bias = _getbias(ps))
return GNNlib.gmm_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GMMConv)
(nin, ein), out = l.ch
print(io, "GMMConv((", nin, ",", ein, ")=>", out)
l.σ == identity || print(io, ", σ=", l.σ)
print(io, ", K=", l.K)
print(io, ", residual=", l.residual)
l.use_bias == true || print(io, ", use_bias=false")
print(io, ")")
end
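# For reference (not part of the diff): `mu` and `sigma_inv` hold the means and
# diagonal inverse widths of the K Gaussian kernels of Monti et al. (2017).
# Roughly, following the GraphNeuralNetworks.jl docstring,
#   w_k(e) = exp(-1/2 * (e - mu_k)' * Sigma_k^{-1} * (e - mu_k)),
# and GNNlib.gmm_conv (see the GNNlib hunk further down) averages
# w_k(e_ij) * (Theta_k * x_j) over both the neighbors j and the K kernels.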

@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
@@ -712,6 +774,8 @@ function LuxCore.parameterlength(l::NNConv)
return n
end

LuxCore.outputsize(l::NNConv) = (l.out_dims,)
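# (This is the output size that shared_testsetup.jl's test_lux_layer checks
# against via size(y) == (outputsize..., g.num_nodes).)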

LuxCore.statelength(l::NNConv) = statelength(l.nn)

function (l::NNConv)(g, x, e, ps, st)
@@ -723,7 +787,59 @@ function (l::NNConv)(g, x, e, ps, st)
end

function Base.show(io::IO, l::NNConv)
print(io, "NNConv($(l.nn)")
print(io, "NNConv($(l.in_dims) => $(l.out_dims), $(l.nn)")
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct ResGatedGraphConv <: GNNLayer
in_dims::Int
out_dims::Int
σ
init_bias
init_weight
use_bias::Bool
end

function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true)
in_dims, out_dims = ch
return ResGatedGraphConv(in_dims, out_dims, σ, init_bias, init_weight, use_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::ResGatedGraphConv)
A = l.init_weight(rng, l.out_dims, l.in_dims)
B = l.init_weight(rng, l.out_dims, l.in_dims)
U = l.init_weight(rng, l.out_dims, l.in_dims)
V = l.init_weight(rng, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; A, B, U, V, bias)
else
return (; A, B, U, V)
end
end

function LuxCore.parameterlength(l::ResGatedGraphConv)
n = 4 * l.in_dims * l.out_dims
if l.use_bias
n += l.out_dims
end
return n
end
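# Illustrative check (not part of this diff): for ResGatedGraphConv(3 => 5),
# A, B, U, V contribute 4 * 3 * 5 = 60 and the bias adds 5, so
# LuxCore.parameterlength(l) == 65.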

LuxCore.outputsize(l::ResGatedGraphConv) = (l.out_dims,)

function (l::ResGatedGraphConv)(g, x, ps, st)
m = (; ps.A, ps.B, ps.U, ps.V, bias = _getbias(ps), l.σ)
return GNNlib.res_gated_graph_conv(m, g, x), st
end

function Base.show(io::IO, l::ResGatedGraphConv)
print(io, "ResGatedGraphConv(", l.in_dims, " => ", l.out_dims)
l.σ == identity || print(io, ", ", l.σ)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end
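# For reference (not part of the diff): A, B, U, V realize the gated
# formulation of Bresson & Laurent (2017), matching the existing Flux layer.
# Roughly, with gates eta_ij = sigmoid.(A * x_i .+ B * x_j),
#   x_i' = sigma.(U * x_i .+ sum over neighbors j of eta_ij .* (V * x_j) .+ bias).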
18 changes: 12 additions & 6 deletions GNNLux/test/layers/conv_tests.jl
@@ -120,12 +120,18 @@
l = NNConv(n_in => n_out, nn, tanh, aggr = +)
x = randn(Float32, n_in, g2.num_nodes)
e = randn(Float32, n_in_edge, g2.num_edges)
+ test_lux_layer(rng, l, g2, x; outputsize=(n_out,), e, container=true)
+ end

- ps = LuxCore.initialparameters(rng, l)
- st = LuxCore.initialstates(rng, l)
+ @testset "GMMConv" begin
+     ein_dims = 4
+     e = randn(rng, Float32, ein_dims, g.num_edges)
+     l = GMMConv((in_dims, ein_dims) => out_dims, tanh; K = 2, residual = false)
+     test_lux_layer(rng, l, g, x; outputsize=(out_dims,), e)
+ end

- y, st′ = l(g2, x, e, ps, st)
-
- @test size(y) == (n_out, g2.num_nodes)
- end
+ @testset "ResGatedGraphConv" begin
+     l = ResGatedGraphConv(in_dims => out_dims, tanh)
+     test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
+ end
end
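(The `rng`, `g`, `x`, `in_dims`, and `out_dims` used by the new testsets are defined by the surrounding test scaffolding, outside this hunk.)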
6 changes: 5 additions & 1 deletion GNNLux/test/shared_testsetup.jl
@@ -42,7 +42,11 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x;
@test size(y) == (outputsize..., g.num_nodes)
end

- loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+ if e !== nothing
+     loss = (x, ps) -> sum(first(l(g, x, e, ps, st)))
+ else
+     loss = (x, ps) -> sum(first(l(g, x, ps, st)))
+ end
test_gradients(loss, x, ps; atol, rtol, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoForwardDiff(), AutoEnzyme()])
end

2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
@@ -389,7 +389,7 @@ function gmm_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
m = propagate(e_mul_xj, g, mean, xj = xj, e = w)
m = dropdims(mean(m, dims = 2), dims = 2) # (out, num_nodes)

- m = l.σ(m .+ l.bias)
+ m = l.σ.(m .+ l.bias)

if l.residual
if size(x, 1) == size(m, 1)
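The added dot is a genuine fix, not a style change: in Julia, `tanh(m)` on a square matrix computes the matrix hyperbolic tangent (a LinearAlgebra matrix function) and errors on non-square input, whereas the layer needs the activation applied elementwise. A quick illustration, outside this PR:

    using LinearAlgebra
    m = randn(Float32, 3, 3)
    tanh(m)    # matrix function: mixes entries across the matrix
    tanh.(m)   # elementwise, what gmm_conv intends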
6 changes: 4 additions & 2 deletions src/layers/conv.jl
@@ -717,7 +717,9 @@ end
function Base.show(io::IO, l::NNConv)
out, in = size(l.weight)
print(io, "NNConv($in => $out")
print(io, ", aggr=", l.aggr)
print(io, ", ", l.nn)
l.σ == identity || print(io, ", ", l.σ)
(l.aggr == +) || print(io, "; aggr=", l.aggr)
print(io, ")")
end

@@ -1136,7 +1138,7 @@ function Base.show(io::IO, l::GMMConv)
print(io, "GMMConv((", nin, ",", ein, ")=>", out)
l.σ == identity || print(io, ", σ=", l.σ)
print(io, ", K=", l.K)
- l.residual == true || print(io, ", residual=", l.residual)
+ print(io, ", residual=", l.residual)
print(io, ")")
end
