From a39c8665a22865ed38352120a4d3e7153c0bf561 Mon Sep 17 00:00:00 2001 From: Kenta Sato Date: Fri, 15 Mar 2024 17:48:30 +0000 Subject: [PATCH] concatenate node i features --- src/MessagePassingIPA.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/MessagePassingIPA.jl b/src/MessagePassingIPA.jl index 459e9a7..a21e71a 100644 --- a/src/MessagePassingIPA.jl +++ b/src/MessagePassingIPA.jl @@ -343,7 +343,7 @@ function GeometricVectorPerceptronGNN( ) gvpstack = Chain( # input layer - GeometricVectorPerceptron((sn + se, vn + ve) => (sn, vn), (sσ, vσ); vector_gate), + GeometricVectorPerceptron((2sn + se, 2vn + ve) => (sn, vn), (sσ, vσ); vector_gate), # intermediate layers [ GeometricVectorPerceptron((sn, vn) => (sn, vn), (sσ, vσ); vector_gate) @@ -361,14 +361,14 @@ function (gnn::GeometricVectorPerceptronGNN)( (se, ve)::Tuple{<:AbstractArray{T, 2}, <:AbstractArray{T, 3}}, ) where T # run message passing - function message(_, xj, e) - s = cat(xj.s, e.s, dims = 1) - v = cat(xj.v, e.v, dims = 2) + function message(xi, xj, e) + s = cat(xi.s, xj.s, e.s, dims = 1) + v = cat(xi.v, xj.v, e.v, dims = 2) gnn.gvpstack((s, v)) end - xj = (s = sn, v = vn) + xi = xj = (s = sn, v = vn) e = (s = se, v = ve) - msgs = apply_edges(message, g; xj, e) + msgs = apply_edges(message, g; xi, xj, e) aggregate_neighbors(g, mean, msgs) # return (s, v) end