Skip to content

Commit

Permalink
added edge check
Browse files Browse the repository at this point in the history
  • Loading branch information
rbSparky committed Aug 8, 2024
1 parent 578968b commit 59ee768
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -655,7 +655,7 @@ end
function (l::MEGNetConv)(g, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕv = StatefulLuxLayer{true}(l.ϕv, ps.ϕv, _getstate(st, :ϕv))
m = (; ϕe, ϕv, l.aggr)
m = (; ϕe, ϕv, aggr=l.aggr)
return GNNlib.megnet_conv(m, g, x, e), st
end

Expand Down
3 changes: 0 additions & 3 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,6 @@
"""

@testset "MEGNetConv" begin
in_dims = 6
out_dims = 8

l = MEGNetConv(in_dims => out_dims)

ps = LuxCore.initialparameters(rng, l)
Expand Down
13 changes: 9 additions & 4 deletions GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -355,18 +355,23 @@ end

####################### MegNetConv ######################################

function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::AbstractMatrix)
function megnet_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractMatrix, Nothing}=nothing)
check_num_nodes(g, x)

= apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e
if isnothing(e)
num_edges = g.num_edges
e = zeros(eltype(x), 0, num_edges) # Empty matrix with correct number of columns
end

ē = apply_edges(g, xi = x, xj = x, e = e) do xi, xj, e
l.ϕe(vcat(xi, xj, e))
end

xᵉ = aggregate_neighbors(g, l.aggr, )
xᵉ = aggregate_neighbors(g, l.aggr, ē)

= l.ϕv(vcat(x, xᵉ))

return x̄,
return x̄, ē
end

####################### GMMConv ######################################
Expand Down

0 comments on commit 59ee768

Please sign in to comment.