[GNNLux] Adding MegNetConv Layer (#480)
* megnet WIP

* fix

* fix

* fix output

* wip

* temporary changes to run tests

* testing

* test

* test

* mean

* mean

* fix

* fix

* fix

* added edge check

* test

* fix

* Update basic_tests.jl

* Update conv_tests.jl: Fixing tests

* Update conv.jl: Back to old commit

* Update conv_tests.jl: Fix tests

* Update conv_tests.jl

* Update conv.jl
rbSparky authored Aug 25, 2024
1 parent 9e9ba9d commit ed78e88
Showing 4 changed files with 58 additions and 2 deletions.
3 changes: 2 additions & 1 deletion GNNLux/src/GNNLux.jl
@@ -1,6 +1,7 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu, swish
using Statistics: mean
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer, parameterlength, statelength, outputsize,
initialparameters, initialstates, parameterlength, statelength
using Lux: Lux, Chain, Dense, GRUCell,
@@ -30,7 +31,7 @@ export AGNNConv,
GINConv,
# GMMConv,
GraphConv,
-    # MEGNetConv,
+    MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
42 changes: 42 additions & 0 deletions GNNLux/src/layers/conv.jl
@@ -628,3 +628,45 @@ function Base.show(io::IO, l::GINConv)
print(io, ", $(l.ϵ)")
print(io, ")")
end

@concrete struct MEGNetConv{TE, TV, A} <: GNNContainerLayer{(:ϕe, :ϕv)}
in_dims::Int
out_dims::Int
ϕe::TE
ϕv::TV
aggr::A
end

function MEGNetConv(in_dims::Int, out_dims::Int, ϕe::TE, ϕv::TV; aggr::A = mean) where {TE, TV, A}
return MEGNetConv{TE, TV, A}(in_dims, out_dims, ϕe, ϕv, aggr)
end

function MEGNetConv(ch::Pair{Int, Int}; aggr = mean)
nin, nout = ch
ϕe = Chain(Dense(3nin, nout, relu),
Dense(nout, nout))

ϕv = Chain(Dense(nin + nout, nout, relu),
Dense(nout, nout))

return MEGNetConv(nin, nout, ϕe, ϕv, aggr=aggr)
end

function (l::MEGNetConv)(g, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv))
m = (; ϕe, ϕv, aggr=l.aggr)
return GNNlib.megnet_conv(m, g, x, e), st
end


LuxCore.outputsize(l::MEGNetConv) = (l.out_dims,)

(l::MEGNetConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function Base.show(io::IO, l::MEGNetConv)
nin = l.in_dims
nout = l.out_dims
print(io, "MEGNetConv(", nin, " => ", nout)
print(io, ")")
end
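
For context, the layer delegates to GNNlib.megnet_conv, which applies the MEGNet update: ϕe maps the concatenation of source node, target node, and edge features (hence the 3nin input width of the default edge network) to new edge features, and ϕv maps the aggregated incoming edge features concatenated with the node features (hence the nin + nout input width) to new node features, with mean as the default aggregation. A minimal usage sketch follows; it is not part of the diff, rand_graph comes from GNNGraphs.jl, and the feature sizes are illustrative.

using GNNLux, GNNGraphs, LuxCore, Random

rng = Random.default_rng()
g = rand_graph(10, 30)                     # small random graph: 10 nodes, 30 edges

l = MEGNetConv(3 => 8)                     # node and edge features: 3 => 8
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

x = randn(rng, Float32, 3, g.num_nodes)    # node feature matrix
e = randn(rng, Float32, 3, g.num_edges)    # edge feature matrix

(x_new, e_new), st = l(g, x, e, ps, st)    # both node and edge features are updated
# size(x_new) == (8, g.num_nodes); size(e_new) == (8, g.num_edges)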
13 changes: 13 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
@@ -93,4 +93,17 @@
l = GINConv(nn, 0.5)
test_lux_layer(rng, l, g, x, sizey=(out_dims,g.num_nodes), container=true)
end

@testset "MEGNetConv" begin
l = MEGNetConv(in_dims => out_dims)

ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)

e = randn(rng, Float32, in_dims, g.num_edges)
(x_new, e_new), st_new = l(g, x, e, ps, st)

@test size(x_new) == (out_dims, g.num_nodes)
@test size(e_new) == (out_dims, g.num_edges)
end
end
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
@@ -724,4 +724,4 @@ function d_conv(l, g::GNNGraph, x::AbstractMatrix)
T1_out = T2_out
end
return h .+ l.bias
-end
\ No newline at end of file
+end
