Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GNNLux] more layers #463

Merged
merged 4 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions GNNGraphs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
Expand All @@ -35,7 +35,7 @@ Functors = "0.4.1"
Graphs = "1.4"
KrylovKit = "0.8"
LinearAlgebra = "1"
LuxDeviceUtils = "0.1.24"
MLDataDevices = "1.0"
MLDatasets = "0.7"
MLUtils = "0.4"
NNlib = "0.9"
Expand Down
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
import Functors
using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice
using MLDataDevices: get_device, cpu_device, CPUDevice

include("chainrules.jl") # hacks for differentiability

Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/test/gnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
dev = LuxCUDADevice() #TODO replace with gpu_device()
dev = CUDADevice()
g_gpu = g |> dev
end

Expand Down Expand Up @@ -141,7 +141,7 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
end

Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/test/query.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
@test eltype(degree(g, Float32)) == Float32

if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
Expand Down Expand Up @@ -87,7 +87,7 @@ end
@test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo)

if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
Expand Down
4 changes: 2 additions & 2 deletions GNNGraphs/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ using Test
using MLDatasets
using InlineStrings # not used but with the import we test #98 and #104
using SimpleWeightedGraphs
using LuxDeviceUtils: gpu_device, cpu_device, get_device
using LuxDeviceUtils: LuxCUDADevice # remove after https://github.com/LuxDL/LuxDeviceUtils.jl/pull/58
using MLDataDevices: gpu_device, cpu_device, get_device
using MLDataDevices: CUDADevice

CUDA.allowscalar(false)

Expand Down
2 changes: 1 addition & 1 deletion GNNGraphs/test/temporalsnapshotsgnngraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ if TEST_GPU
snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
tsg.tgdata.x = rand(5)
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
tsg = tsg |> dev
@test tsg.snapshots[1].ndata.x isa CuArray
@test tsg.snapshots[end].ndata.x isa CuArray
Expand Down
3 changes: 2 additions & 1 deletion GNNLux/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,11 @@ julia = "1.10"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
test = ["Test", "MLDataDevices", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
24 changes: 21 additions & 3 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using NNlib: NNlib, sigmoid, relu
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, glorot_uniform, zeros32
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
Expand All @@ -14,9 +14,27 @@ export GNNLayer,
GNNChain

include("layers/conv.jl")
export GCNConv,
export AGNNConv,
CGConv,
ChebConv,
EdgeConv,
# EGNNConv,
# DConv,
# GATConv,
# GATv2Conv,
# GatedGraphConv,
GCNConv,
# GINConv,
# GMMConv,
GraphConv
# MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
# SGConv,
# TAGConv,
# TransformerConv


end #module

156 changes: 121 additions & 35 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,7 @@
# Missing Layers

# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
# | :-------- | :---: |:---: |:---: | :---: | :---: |
# | [`AGNNConv`](@ref) | | | ✓ | | |
# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`EGNNConv`](@ref) | | | ✓ | | |
# | [`EdgeConv`](@ref) | | | | ✓ | |
# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`GMMConv`](@ref) | | | ✓ | | |
# | [`MEGNetConv`](@ref) | | | ✓ | | |
# | [`NNConv`](@ref) | | | ✓ | | |
# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`SGConv`](@ref) | ✓ | | | | ✓ |
# | [`TransformerConv`](@ref) | | | ✓ | | |
_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
_getstate(s::StatefulLuxLayer{true}) = s.st
_getstate(s::StatefulLuxLayer{false}) = s.st_any


@concrete struct GCNConv <: GNNLayer
Expand Down Expand Up @@ -65,13 +50,18 @@ function Base.show(io::IO, l::GCNConv)
print(io, ")")
end

# TODO norm_fn should be keyword argument only
(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st
l(g, x, edge_weight, ps, st; conv_weight, norm_fn)

function (l::GCNConv)(g, x, edge_weight, ps, st;
norm_fn = d -> 1 ./ sqrt.(d),
conv_weight=nothing, )

m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.σ)
y = GNNlib.gcn_conv(m, g, x, edge_weight, norm_fn, conv_weight)
return y, st
end

@concrete struct ChebConv <: GNNLayer
in_dims::Int
Expand All @@ -80,17 +70,14 @@ end
k::Int
init_weight
init_bias
σ
end

function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
function ChebConv(ch::Pair{Int, Int}, k::Int;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
use_bias::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
Expand All @@ -109,13 +96,17 @@ LuxCore.statelength(d::ChebConv) = 0
LuxCore.outputsize(d::ChebConv) = (d.out_dims,)

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.K)
l.σ == identity || print(io, ", ", l.σ)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st
function (l::ChebConv)(g, x, ps, st)
m = (; ps.weight, bias = _getbias(ps), l.k)
y = GNNlib.cheb_conv(m, g, x)
return y, st

end

@concrete struct GraphConv <: GNNLayer
in_dims::Int
Expand Down Expand Up @@ -168,4 +159,99 @@ function Base.show(io::IO, l::GraphConv)
print(io, ")")
end

(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
function (l::GraphConv)(g, x, ps, st)
m = (; ps.weight1, ps.weight2, bias = _getbias(ps),
l.σ, l.aggr)
return GNNlib.graph_conv(m, g, x), st
end


@concrete struct AGNNConv <: GNNLayer
init_beta <: AbstractVector
add_self_loops::Bool
trainable::Bool
end

function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
return AGNNConv([init_beta], add_self_loops, trainable)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::AGNNConv)
if l.trainable
return (; β = l.init_beta)
else
return (;)
end
end

LuxCore.parameterlength(l::AGNNConv) = l.trainable ? 1 : 0
LuxCore.statelength(d::AGNNConv) = 0

function Base.show(io::IO, l::AGNNConv)
print(io, "AGNNConv(", l.init_beta)
l.add_self_loops || print(io, ", add_self_loops=false")
l.trainable || print(io, ", trainable=false")
print(io, ")")
end

function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
β = l.trainable ? ps.β : l.init_beta
m = (; β, l.add_self_loops)
return GNNlib.agnn_conv(m, g, x), st
end

@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)}
in_dims::NTuple{2, Int}
out_dims::Int
dense_f
dense_s
residual::Bool
init_weight
init_bias
end

CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)

function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32,
allow_fast_activation = true)
(nin, ein), out = ch
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation)
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation)
return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
end

LuxCore.outputsize(l::CGConv) = (l.out_dims,)

(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::CGConv)(g, x, e, ps, st)
dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s))
m = (; dense_f, dense_s, l.residual)
return GNNlib.cg_conv(m, g, x, e), st
end

@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
aggr
end

EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)

function Base.show(io::IO, l::EdgeConv)
print(io, "EdgeConv(", l.nn)
print(io, ", aggr=", l.aggr)
print(io, ")")
end


function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
m = (; nn, l.aggr)
y = GNNlib.edge_conv(m, g, x)
stnew = _getstate(nn)
return y, stnew
end


45 changes: 44 additions & 1 deletion GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2, relu)
l = ChebConv(3 => 5, 2)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
Expand Down Expand Up @@ -47,4 +47,47 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(ps) == 1
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test size(y) == size(x)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
l = EdgeConv(nn, aggr = +)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "CGConv" begin
l = CGConv(3 => 5, residual = true)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
@test Lux.outputsize(l) == (5,)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end
Loading
Loading