From c7eb715988d7f1ef287ea68b6e073f213dcd3da1 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 1 Aug 2024 18:28:19 +0200 Subject: [PATCH] more layers --- GNNLux/src/GNNLux.jl | 6 +++-- GNNLux/src/layers/conv.jl | 45 +++++++++++++++++++++++--------- GNNLux/test/layers/conv_tests.jl | 12 ++++++++- GNNLux/test/shared_testsetup.jl | 1 + 4 files changed, 48 insertions(+), 16 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 848549069..3cfca11c3 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -3,7 +3,9 @@ using ConcreteStructs: @concrete using NNlib: NNlib, sigmoid, relu, swish using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, initialparameters, initialstates, parameterlength, statelength -using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer +using Lux: Lux, Chain, Dense, GRUCell, + glorot_uniform, zeros32, + StatefulLuxLayer using Reexport: @reexport using Random: AbstractRNG using GNNlib: GNNlib @@ -25,7 +27,7 @@ export AGNNConv, GATv2Conv, GatedGraphConv, GCNConv, - # GINConv, + GINConv, # GMMConv, GraphConv # MEGNetConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 116daac96..872a3569c 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -38,7 +38,6 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv) end LuxCore.parameterlength(l::GCNConv) = l.use_bias ? l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims -LuxCore.statelength(d::GCNConv) = 0 LuxCore.outputsize(d::GCNConv) = (d.out_dims,) function Base.show(io::IO, l::GCNConv) @@ -518,7 +517,7 @@ function Base.show(io::IO, l::GATv2Conv) end -@concrete struct GatedGraphConv <: GRULayer +@concrete struct GatedGraphConv <: GNNLayer gru init_weight dims::Int @@ -533,28 +532,48 @@ function GatedGraphConv(dims::Int, num_layers::Int; return GatedGraphConv(gru, init_weight, dims, num_layers, aggr) end -LucCore.outputsize(l::GatedGraphConv) = (l.dims,) +LuxCore.outputsize(l::GatedGraphConv) = (l.dims,) function LuxCore.initialparameters(rng::AbstractRNG, l::GatedGraphConv) gru = LuxCore.initialparameters(rng, l.gru) - weight = l.init_weight(rng, l.dims, l.dims) + weight = l.init_weight(rng, l.dims, l.dims, l.num_layers) return (; gru, weight) end -LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2 +LuxCore.parameterlength(l::GatedGraphConv) = parameterlength(l.gru) + l.dims^2*l.num_layers -function LuxCore.initialstates(rng::AbstractRNG, l::GatedGraphConv) - return (; gru = LuxCore.initialstates(rng, l.gru)) -end - -LuxCore.statelength(l::GatedGraphConv) = statelength(l.gru) -function (l::GatedGraphConv)(g, H, ps, st) - GNNlib.gated_graph_conv(l, g, H) +function (l::GatedGraphConv)(g, x, ps, st) + gru = StatefulLuxLayer{true}(l.gru, ps.gru, _getstate(st, :gru)) + fgru = (h, x) -> gru((x, (h,))) # make the forward compatible with Flux.GRUCell style + m = (; gru=fgru, ps.weight, l.num_layers, l.aggr, l.dims) + return GNNlib.gated_graph_conv(m, g, x), st end function Base.show(io::IO, l::GatedGraphConv) print(io, "GatedGraphConv($(l.dims), $(l.num_layers)") print(io, ", aggr=", l.aggr) print(io, ")") -end \ No newline at end of file +end + +@concrete struct GINConv <: GNNContainerLayer{(:nn,)} + nn <: AbstractExplicitLayer + ϵ <: Real + aggr +end + +GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr) + +function (l::GINConv)(g, x, ps, st) + nn = StatefulLuxLayer{true}(l.nn, ps, st) + m = (; nn, l.ϵ, l.aggr) + y = GNNlib.gin_conv(m, g, x) + stnew = _getstate(nn) + return y, stnew +end + +function Base.show(io::IO, l::GINConv) + print(io, "GINConv($(l.nn)") + print(io, ", $(l.ϵ)") + print(io, ")") +end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index b2e81173d..443230830 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -77,5 +77,15 @@ #TODO test edge end -end + @testset "GatedGraphConv" begin + l = GatedGraphConv(in_dims, 3) + test_lux_layer(rng, l, g, x, outputsize=(in_dims,)) + end + + @testset "GINConv" begin + nn = Chain(Dense(in_dims => out_dims, relu), Dense(out_dims => out_dims)) + l = GINConv(nn, 0.5) + test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) + end +end diff --git a/GNNLux/test/shared_testsetup.jl b/GNNLux/test/shared_testsetup.jl index 1354ef387..b6b80df49 100644 --- a/GNNLux/test/shared_testsetup.jl +++ b/GNNLux/test/shared_testsetup.jl @@ -28,6 +28,7 @@ function test_lux_layer(rng::AbstractRNG, l, g::GNNGraph, x; @test LuxCore.statelength(l) == LuxCore.statelength(st) y, st′ = l(g, x, ps, st) + @test eltype(y) == eltype(x) if outputsize !== nothing @test LuxCore.outputsize(l) == outputsize end