Skip to content

Commit

Permalink
[GNNLux] more layers (#469)
Browse files Browse the repository at this point in the history
* layers

* fixes
  • Loading branch information
CarloLucibello authored Jul 30, 2024
1 parent fc67808 commit 4b4477e
Show file tree
Hide file tree
Showing 5 changed files with 341 additions and 33 deletions.
12 changes: 6 additions & 6 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib, sigmoid, relu
using NNlib: NNlib, sigmoid, relu, swish
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Lux: Lux, Chain, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
Expand All @@ -18,10 +18,10 @@ export AGNNConv,
CGConv,
ChebConv,
EdgeConv,
# EGNNConv,
# DConv,
# GATConv,
# GATv2Conv,
EGNNConv,
DConv,
GATConv,
GATv2Conv,
# GatedGraphConv,
GCNConv,
# GINConv,
Expand Down
261 changes: 261 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,3 +255,264 @@ function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
end


@concrete struct EGNNConv <: GNNContainerLayer{(:ϕe, :ϕx, :ϕh)}
ϕe
ϕx
ϕh
num_features
residual::Bool
end

function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
end

#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1],
residual = false)
(in_size, edge_feat_size), out_size = ch
act_fn = swish

# +1 for the radial feature: ||x_i - x_j||^2
ϕe = Chain(Dense(in_size * 2 + edge_feat_size + 1 => hidden_size, act_fn),
Dense(hidden_size => hidden_size, act_fn))

ϕh = Chain(Dense(in_size + hidden_size => hidden_size, swish),
Dense(hidden_size => out_size))

ϕx = Chain(Dense(hidden_size => hidden_size, swish),
Dense(hidden_size => 1, use_bias = false))

num_features = (in = in_size, edge = edge_feat_size, out = out_size,
hidden = hidden_size)
if residual
@assert in_size==out_size "Residual connection only possible if in_size == out_size"
end
return EGNNConv(ϕe, ϕx, ϕh, num_features, residual)
end

LuxCore.outputsize(l::EGNNConv) = (l.num_features.out,)

(l::EGNNConv)(g, h, x, ps, st) = l(g, h, x, nothing, ps, st)

function (l::EGNNConv)(g, h, x, e, ps, st)
ϕe = StatefulLuxLayer{true}(l.ϕe, ps.ϕe, _getstate(st, :ϕe))
ϕx = StatefulLuxLayer{true}(l.ϕx, ps.ϕx, _getstate(st, :ϕx))
ϕh = StatefulLuxLayer{true}(l.ϕh, ps.ϕh, _getstate(st, :ϕh))
m = (; ϕe, ϕx, ϕh, l.residual, l.num_features)
return GNNlib.egnn_conv(m, g, h, x, e), st
end

function Base.show(io::IO, l::EGNNConv)
ne = l.num_features.edge
nin = l.num_features.in
nout = l.num_features.out
nh = l.num_features.hidden
print(io, "EGNNConv(($nin, $ne) => $nout; hidden_size=$nh")
if l.residual
print(io, ", residual=true")
end
print(io, ")")
end

@concrete struct DConv <: GNNLayer
in_dims::Int
out_dims::Int
k::Int
init_weight
init_bias
use_bias::Bool
end

function DConv(ch::Pair{Int, Int}, k::Int;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias = true)
in, out = ch
return DConv(in, out, k, init_weight, init_bias, use_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::DConv)
weights = l.init_weight(rng, 2, l.k, l.out_dims, l.in_dims)
if l.use_bias
bias = l.init_bias(rng, l.out_dims)
return (; weights, bias)
else
return (; weights)
end
end

LuxCore.outputsize(l::DConv) = (l.out_dims,)
LuxCore.parameterlength(l::DConv) = l.use_bias ? 2 * l.in_dims * l.out_dims * l.k + l.out_dims :
2 * l.in_dims * l.out_dims * l.k

function (l::DConv)(g, x, ps, st)
m = (; ps.weights, bias = _getbias(ps), l.k)
return GNNlib.d_conv(m, g, x), st
end

function Base.show(io::IO, l::DConv)
print(io, "DConv($(l.in_dims) => $(l.out_dims), k=$(l.k))")
end

@concrete struct GATConv <: GNNLayer
dense_x
dense_e
init_weight
init_bias
use_bias::Bool
σ
negative_slope
channel::Pair{NTuple{2, Int}, Int}
heads::Int
concat::Bool
add_self_loops::Bool
dropout
end


GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

function GATConv(ch::Pair{NTuple{2, Int}, Int}, σ = identity;
heads::Int = 1, concat::Bool = true, negative_slope = 0.2,
init_weight = glorot_uniform, init_bias = zeros32,
use_bias::Bool = true,
add_self_loops = true, dropout=0.0)
(in, ein), out = ch
if add_self_loops
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

dense_x = Dense(in => out * heads, use_bias = false)
dense_e = ein > 0 ? Dense(ein => out * heads, use_bias = false) : nothing
negative_slope = convert(Float32, negative_slope)
return GATConv(dense_x, dense_e, init_weight, init_bias, use_bias,
σ, negative_slope, ch, heads, concat, add_self_loops, dropout)
end

LuxCore.outputsize(l::GATConv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],)
##TODO: parameterlength

function LuxCore.initialparameters(rng::AbstractRNG, l::GATConv)
(in, ein), out = l.channel
dense_x = LuxCore.initialparameters(rng, l.dense_x)
a = l.init_weight(ein > 0 ? 3out : 2out, l.heads)
ps = (; dense_x, a)
if ein > 0
ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
end
if l.use_bias
ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
end
return ps
end

(l::GATConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::GATConv)(g, x, e, ps, st)
dense_x = StatefulLuxLayer{true}(l.dense_x, ps.dense_x, _getstate(st, :dense_x))
dense_e = l.dense_e === nothing ? nothing :
StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))

m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
ps.a, bias = _getbias(ps), dense_x, dense_e, l.negative_slope)
return GNNlib.gat_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GATConv)
(in, ein), out = l.channel
print(io, "GATConv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end

@concrete struct GATv2Conv <: GNNLayer
dense_i
dense_j
dense_e
init_weight
init_bias
use_bias::Bool
σ
negative_slope
channel::Pair{NTuple{2, Int}, Int}
heads::Int
concat::Bool
add_self_loops::Bool
dropout
end

function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
GATv2Conv((ch[1], 0) => ch[2], args...; kws...)
end

function GATv2Conv(ch::Pair{NTuple{2, Int}, Int},
σ = identity;
heads::Int = 1,
concat::Bool = true,
negative_slope = 0.2,
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
add_self_loops = true,
dropout=0.0)

(in, ein), out = ch

if add_self_loops
@assert ein==0 "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

dense_i = Dense(in => out * heads; use_bias, init_weight, init_bias)
dense_j = Dense(in => out * heads; use_bias = false, init_weight)
if ein > 0
dense_e = Dense(ein => out * heads; use_bias = false, init_weight)
else
dense_e = nothing
end
return GATv2Conv(dense_i, dense_j, dense_e,
init_weight, init_bias, use_bias,
σ, negative_slope,
ch, heads, concat, add_self_loops, dropout)
end


LuxCore.outputsize(l::GATv2Conv) = (l.concat ? l.channel[2]*l.heads : l.channel[2],)
##TODO: parameterlength

function LuxCore.initialparameters(rng::AbstractRNG, l::GATv2Conv)
(in, ein), out = l.channel
dense_i = LuxCore.initialparameters(rng, l.dense_i)
dense_j = LuxCore.initialparameters(rng, l.dense_j)
a = l.init_weight(out, l.heads)
ps = (; dense_i, dense_j, a)
if ein > 0
ps = (ps..., dense_e = LuxCore.initialparameters(rng, l.dense_e))
end
if l.use_bias
ps = (ps..., bias = l.init_bias(rng, l.concat ? out * l.heads : out))
end
return ps
end

(l::GATv2Conv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::GATv2Conv)(g, x, e, ps, st)
dense_i = StatefulLuxLayer{true}(l.dense_i, ps.dense_i, _getstate(st, :dense_i))
dense_j = StatefulLuxLayer{true}(l.dense_j, ps.dense_j, _getstate(st, :dense_j))
dense_e = l.dense_e === nothing ? nothing :
StatefulLuxLayer{true}(l.dense_e, ps.dense_e, _getstate(st, :dense_e))

m = (; l.add_self_loops, l.channel, l.heads, l.concat, l.dropout, l.σ,
ps.a, bias = _getbias(ps), dense_i, dense_j, dense_e, l.negative_slope)
return GNNlib.gatv2_conv(m, g, x, e), st
end

function Base.show(io::IO, l::GATv2Conv)
(in, ein), out = l.channel
print(io, "GATv2Conv(", ein == 0 ? in : (in, ein), " => ", out ÷ l.heads)
l.σ == identity || print(io, ", ", l.σ)
print(io, ", negative_slope=", l.negative_slope)
print(io, ")")
end
69 changes: 57 additions & 12 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,81 @@
@testitem "layers/conv" setup=[SharedTestSetup] begin
rng = StableRNG(1234)
g = rand_graph(10, 40, seed=1234)
x = randn(rng, Float32, 3, 10)
in_dims = 3
out_dims = 5
x = randn(rng, Float32, in_dims, 10)

@testset "GCNConv" begin
l = GCNConv(3 => 5, relu)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = GCNConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = ChebConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "GraphConv" begin
l = GraphConv(3 => 5, relu)
test_lux_layer(rng, l, g, x, outputsize=(5,))
l = GraphConv(in_dims => out_dims, relu)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
test_lux_layer(rng, l, g, x, sizey=(3,10))
test_lux_layer(rng, l, g, x, sizey=(in_dims, 10))
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
nn = Chain(Dense(2*in_dims => 5, relu), Dense(5 => out_dims))
l = EdgeConv(nn, aggr = +)
test_lux_layer(rng, l, g, x, sizey=(5,10), container=true)
test_lux_layer(rng, l, g, x, sizey=(out_dims,10), container=true)
end

@testset "CGConv" begin
l = CGConv(3 => 3, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(3,), container=true)
l = CGConv(in_dims => in_dims, residual = true)
test_lux_layer(rng, l, g, x, outputsize=(in_dims,), container=true)
end

@testset "DConv" begin
l = DConv(in_dims => out_dims, 2)
test_lux_layer(rng, l, g, x, outputsize=(5,))
end

@testset "EGNNConv" begin
hin = 6
hout = 7
hidden = 8
l = EGNNConv(hin => hout, hidden)
ps = LuxCore.initialparameters(rng, l)
st = LuxCore.initialstates(rng, l)
h = randn(rng, Float32, hin, g.num_nodes)
(hnew, xnew), stnew = l(g, h, x, ps, st)
@test size(hnew) == (hout, g.num_nodes)
@test size(xnew) == (in_dims, g.num_nodes)
end

@testset "GATConv" begin
x = randn(rng, Float32, 6, 10)

l = GATConv(6 => 8, heads=2)
test_lux_layer(rng, l, g, x, outputsize=(16,))

l = GATConv(6 => 8, heads=2, concat=false, dropout=0.5)
test_lux_layer(rng, l, g, x, outputsize=(8,))

#TODO test edge
end

@testset "GATv2Conv" begin
x = randn(rng, Float32, 6, 10)

l = GATv2Conv(6 => 8, heads=2)
test_lux_layer(rng, l, g, x, outputsize=(16,))

l = GATv2Conv(6 => 8, heads=2, concat=false, dropout=0.5)
test_lux_layer(rng, l, g, x, outputsize=(8,))

#TODO test edge
end
end

Loading

0 comments on commit 4b4477e

Please sign in to comment.