Skip to content

Commit

Permalink
fix GraphKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Jul 7, 2022
1 parent 88d6f7c commit cf34d62
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 11 deletions.
13 changes: 9 additions & 4 deletions src/graph_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,22 @@ end

Flux.@functor GraphKernel

function GeometricFlux.message(l::GraphKernel, x_i::AbstractArray, x_j::AbstractArray, e_ij)
return l.κ(vcat(x_i, x_j))
function GeometricFlux.message(l::GraphKernel, x_i, x_j::AbstractArray, e_ij::AbstractArray)
N = size(x_j, 1)
K = l.κ(e_ij)
dims = size(K)[2:end]
m_ij = GeometricFlux._matmul(reshape(K, N, N, :), reshape(x_j, N, 1, :))
return reshape(m_ij, N, dims...)
end

function GeometricFlux.update(l::GraphKernel, m::AbstractArray, x::AbstractArray)
return l.σ.(GeometricFlux._matmul(l.linear, x) + m)
end

function (l::GraphKernel)(el::NamedTuple, X::AbstractArray)
function (l::GraphKernel)(el::NamedTuple, X::AbstractArray, E::AbstractArray)
GraphSignals.check_num_nodes(el.N, X)
_, V, _ = GeometricFlux.propagate(l, el, nothing, X, nothing, mean, nothing, nothing)
GraphSignals.check_num_nodes(el.E, E)
_, V, _ = GeometricFlux.propagate(l, el, E, X, nothing, mean, nothing, nothing)
return V
end

Expand Down
17 changes: 10 additions & 7 deletions test/graph_kernel.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
@testset "GraphKernel" begin
batch_size = 5
channel = 32
N = 10 * 10
coord_dim = 2
N = 10

κ = Dense(2 * channel, channel, relu)
graph = grid([N, N])
κ = Dense(2(coord_dim + 1), abs2(channel), relu)

graph = grid([10, 10])
𝐱 = rand(Float32, channel, N, batch_size)
𝐱 = rand(Float32, channel, nv(graph), batch_size)
E = rand(Float32, 2(coord_dim + 1), ne(graph), batch_size)
l = WithGraph(FeaturedGraph(graph), GraphKernel(κ, channel))
@test repr(l.layer) == "GraphKernel(Dense(64 => 32, relu), channel=32)"
@test size(l(𝐱)) == (channel, N, batch_size)
@test repr(l.layer) ==
"GraphKernel(Dense($(2(coord_dim + 1)) => $(abs2(channel)), relu), channel=32)"
@test size(l(𝐱, E)) == (channel, nv(graph), batch_size)

g = Zygote.gradient(() -> sum(l(𝐱)), Flux.params(l))
g = Zygote.gradient(() -> sum(l(𝐱, E)), Flux.params(l))
@test length(g.grads) == 3
end

0 comments on commit cf34d62

Please sign in to comment.