From ed78e8831e7f691420e66cc66115135fffab3faf Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Sun, 25 Aug 2024 13:43:31 +0530 Subject: [PATCH] [GNNLux] Adding MegNetConv Layer (#480) * megnet WIP * fix * fix * fix output * wip * temporary changes to run tests * testing * test * test * mean * mean * fix * fix * fix * added edge check * test * fix * Update basic_tests.jl * Update conv_tests.jl: Fixing tests * Update conv.jl: Back to old commit * Update conv_tests.jl: Fix tests * Update conv_tests.jl * Update conv.jl --- GNNLux/src/GNNLux.jl | 3 ++- GNNLux/src/layers/conv.jl | 42 ++++++++++++++++++++++++++++++++ GNNLux/test/layers/conv_tests.jl | 13 ++++++++++ GNNlib/src/layers/conv.jl | 2 +- 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 40a12b25e..cd222ab1c 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -1,6 +1,7 @@ module GNNLux using ConcreteStructs: @concrete using NNlib: NNlib, sigmoid, relu, swish +using Statistics: mean using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize, initialparameters, initialstates, parameterlength, statelength using Lux: Lux, Chain, Dense, GRUCell, @@ -30,7 +31,7 @@ export AGNNConv, GINConv, # GMMConv, GraphConv, - # MEGNetConv, + MEGNetConv, # NNConv, # ResGatedGraphConv, # SAGEConv, diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 83c3efddc..30564ae48 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -628,3 +628,45 @@ function Base.show(io::IO, l::GINConv) print(io, ", $(l.ϵ)") print(io, ")") end + +@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)} + in_dims::Int + out_dims::Int + ϕe::TE + ϕv::TV + aggr::A +end + +function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A} + return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr) +end + +function MEGNetConv(ch::Pair{Int, Int}; aggr = mean) + nin, nout = ch + ϕe = Chain(Dense(3nin, nout, relu), + Dense(nout, nout)) + + ϕv = Chain(Dense(nin + nout, nout, relu), + Dense(nout, nout)) + + return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr) +end + +function (l::MEGNetConv)(g, x, e, ps, st) + ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) + ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) + m = (; ϕe, ϕv, aggr=l.aggr) + return GNNlib.megnet_conv(m, g, x, e), st +end + + +LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,) + +(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st) + +function Base.show(io::IO, l::MEGNetConv) + nin = l.in_dims + nout = l.out_dims + print(io, "MEGNetConv(", nin, " => ", nout) + print(io, ")") +end \ No newline at end of file diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ab06c9445..86a056977 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -93,4 +93,17 @@ l = GINConv(nn, 0.5) test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true) end + + @testset "MEGNetConv" begin + l = MEGNetConv(in_dims => out_dims) + + ps = LuxCore.initialparameters(rng, l) + st = LuxCore.initialstates(rng, l) + + e = randn(rng, Float32, in_dims, g.num_edges) + (x_new, e_new), st_new = l(g, x, e, ps, st) + + @test size(x_new) == (out_dims, g.num_nodes) + @test size(e_new) == (out_dims, g.num_edges) + end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 3a5c543a1..9caa4280f 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -724,4 +724,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix) T1_out = T2_out end return h .+ l.bias -end \ No newline at end of file +end