diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 9a7259689..30564ae48 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -655,7 +655,7 @@ end function (l::MEGNetConv)(g, x, e, ps, st) ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe)) ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv)) - m = (; ϕe, ϕv, l.aggr) + m = (; ϕe, ϕv, aggr=l.aggr) return GNNlib.megnet_conv(m, g, x, e), st end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index ad1bc7a2b..ca1ed68d6 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -96,9 +96,6 @@ """ @testset "MEGNetConv" begin - in_dims = 6 - out_dims = 8 - l = MEGNetConv(in_dims => out_dims) ps = LuxCore.initialparameters(rng, l) diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index 50b5b34aa..4ad3a8768 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -355,18 +355,23 @@ end ####################### MegNetConv ###################################### -function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix) +function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing) check_num_nodes(g, x) - ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e + if isnothing(e) + num_edges = g.num_edges + e = zeros(eltype(x), 0, num_edges) # Empty matrix with correct number of columns + end + + ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e l.ϕe(vcat(xi, xj, e)) end - xᵉ = aggregate_neighbors(g, l.aggr, ē) + xᵉ = aggregate_neighbors(g, l.aggr, ē) x̄ = l.ϕv(vcat(x, xᵉ)) - return x̄, ē + return x̄, ē end ####################### GMMConv ######################################