Skip to content

Commit

Permalink
concatenate node i features
Browse files Browse the repository at this point in the history
  • Loading branch information
bicycle1885 committed Mar 15, 2024
1 parent cfc3736 commit a39c866
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/MessagePassingIPA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ function GeometricVectorPerceptronGNN(
)
gvpstack = Chain(
# input layer
GeometricVectorPerceptron((sn + se, vn + ve) => (sn, vn), (sσ, vσ); vector_gate),
GeometricVectorPerceptron((2sn + se, 2vn + ve) => (sn, vn), (sσ, vσ); vector_gate),
# intermediate layers
[
GeometricVectorPerceptron((sn, vn) => (sn, vn), (sσ, vσ); vector_gate)
Expand All @@ -361,14 +361,14 @@ function (gnn::GeometricVectorPerceptronGNN)(
(se, ve)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}},
) where T
# run message passing
function message(_, xj, e)
s = cat(xj.s, e.s, dims = 1)
v = cat(xj.v, e.v, dims = 2)
function message(xi, xj, e)
s = cat(xi.s, xj.s, e.s, dims = 1)
v = cat(xi.v, xj.v, e.v, dims = 2)
gnn.gvpstack((s, v))
end
xj = (s = sn, v = vn)
xi = xj = (s = sn, v = vn)
e = (s = se, v = ve)
msgs = apply_edges(message, g; xj, e)
msgs = apply_edges(message, g; xi, xj, e)
aggregate_neighbors(g, mean, msgs) # return (s, v)
end

Expand Down

0 comments on commit a39c866

Please sign in to comment.